From 57a8196c4b52cd0cb5edac093a43161a9d53aa74 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kristof=20Schr=C3=B6der?=
Date: Thu, 7 Dec 2023 10:01:04 +0100
Subject: [PATCH] Add sequential calculator

---
 src/pydvl/influence/influence_calculator.py  | 192 ++++++++++++++++++-
 src/pydvl/influence/torch/util.py            |  47 ++++-
 tests/influence/test_influence_calculator.py |  51 ++++-
 3 files changed, 285 insertions(+), 5 deletions(-)

diff --git a/src/pydvl/influence/influence_calculator.py b/src/pydvl/influence/influence_calculator.py
index f6f1a176e..1e5b94f08 100644
--- a/src/pydvl/influence/influence_calculator.py
+++ b/src/pydvl/influence/influence_calculator.py
@@ -1,9 +1,18 @@
 from abc import ABC, abstractmethod
 from math import prod
-from typing import Any, Callable, Generic, Optional, Tuple
+from typing import (
+    Callable,
+    Generator,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 
 import distributed
-import numpy as np
+import zarr
 from dask import array as da
 from dask import delayed
 from numpy.typing import NDArray
@@ -357,3 +366,182 @@ def _get_client() -> Optional[distributed.Client]:
         return distributed.get_client()
     except ValueError:
         return None
+
+
+class BlockAggregator(Generic[TensorType], ABC):
+    @abstractmethod
+    def aggregate_nested(
+        self, tensors: Generator[Generator[TensorType, None, None], None, None]
+    ):
+        """Override this method to aggregate the blocks provided by a nested
+        generator into a single tensor"""
+
+    @abstractmethod
+    def aggregate(self, tensors: Generator[TensorType, None, None]):
+        """Override this method to aggregate the provided tensors into a single
+        tensor"""
+
+
+class ListAggregator(BlockAggregator):
+    def aggregate_nested(
+        self, tensors: Generator[Generator[TensorType, None, None], None, None]
+    ):
+        return [list(tensor_gen) for tensor_gen in tensors]
+
+    def aggregate(self, tensors: Generator[TensorType, None, None]):
+        return list(tensors)
+
+
+class SequentialInfluenceCalculator:
+    """
+    Simple wrapper class to process batches of data sequentially. Depends on a
+    batch computation model of type
+    [InfluenceFunctionModel][pydvl.influence.base_influence_model.InfluenceFunctionModel].
+
+    Args:
+        influence_function_model: instance of type
+            [InfluenceFunctionModel][pydvl.influence.base_influence_model.InfluenceFunctionModel],
+            which defines the batch-wise computation model
+        block_aggregator: optional instance of type
+            [BlockAggregator][pydvl.influence.influence_calculator.BlockAggregator],
+            used to collect and aggregate the tensors from the sequential
+            process. If None, tensors are collected into list structures
+    """
+
+    def __init__(
+        self,
+        influence_function_model: InfluenceFunctionModel,
+        block_aggregator: Optional[BlockAggregator] = None,
+    ):
+        self.block_aggregator = (
+            block_aggregator if block_aggregator is not None else ListAggregator()
+        )
+        self.influence_function_model = influence_function_model
+
+    def _influence_factors_gen(
+        self, data_iterable: Iterable[Tuple[TensorType, TensorType]]
+    ) -> Generator[TensorType, None, None]:
+        for x, y in iter(data_iterable):
+            yield self.influence_function_model.influence_factors(x, y)
+
+    def influence_factors(
+        self,
+        data_iterable: Iterable[Tuple[TensorType, TensorType]],
+    ) -> TensorType:
+        r"""
+        Compute the expression
+
+        \[ H^{-1}\nabla_{\theta} \ell(y, f_{\theta}(x)) \]
+
+        where the gradients are computed for the chunks $(x, y)$ of the
+        data_iterable in a sequential manner and aggregated into a single
+        tensor.
+
+        Args:
+            data_iterable: an iterable over batches of data, where each element
+                is a tuple of tensors $(x, y)$
+
+        Returns:
+            Tensor representing the aggregated inverse Hessian matrix vector
+            products for the provided data.
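+
+        Example:
+            A minimal usage sketch, assuming an already fitted influence
+            function model `if_model` (e.g. an instance of
+            [ArnoldiInfluence][pydvl.influence.torch.ArnoldiInfluence]) and a
+            torch data loader `test_dataloader`; these names are illustrative:
+
+            ```python
+            from pydvl.influence.torch.util import TorchCatAggregator
+
+            calculator = SequentialInfluenceCalculator(
+                if_model, block_aggregator=TorchCatAggregator()
+            )
+            # a single tensor of inverse Hessian vector products, concatenated
+            # over all batches of test_dataloader
+            factors = calculator.influence_factors(test_dataloader)
+            ```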
+        """
+        tensors_gen = self._influence_factors_gen(data_iterable)
+        t: TensorType = self.block_aggregator.aggregate(tensors_gen)
+        return t
+
+    def _influences_gen(
+        self,
+        test_data_iterable: Iterable[Tuple[TensorType, TensorType]],
+        train_data_iterable: Iterable[Tuple[TensorType, TensorType]],
+        influence_type: InfluenceType,
+    ) -> Generator[Generator[TensorType, None, None], None, None]:
+        for x_test, y_test in iter(test_data_iterable):
+            yield (
+                self.influence_function_model.influences(
+                    x_test, y_test, x, y, influence_type
+                )
+                for x, y in iter(train_data_iterable)
+            )
+
+    def influences(
+        self,
+        test_data_iterable: Iterable[Tuple[TensorType, TensorType]],
+        train_data_iterable: Iterable[Tuple[TensorType, TensorType]],
+        influence_type: InfluenceType = InfluenceType.Up,
+    ) -> TensorType:
+        r"""
+        Compute an approximation of
+
+        \[ \langle H^{-1}\nabla_{\theta} \ell(y_{\text{test}},
+        f_{\theta}(x_{\text{test}})), \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
+
+        for the case of up-weighting influence, or
+
+        \[ \langle H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, f_{\theta}(x_{\text{test}})),
+        \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
+
+        for the case of perturbation type influence. The computation is done
+        block-wise for the chunks of the provided data iterables and aggregated
+        into a single tensor in memory.
+
+        Args:
+            test_data_iterable: an iterable over batches of test data, where
+                each element is a tuple of tensors $(x_{\text{test}},
+                y_{\text{test}})$
+            train_data_iterable: an iterable over batches of training data,
+                where each element is a tuple of tensors $(x, y)$
+            influence_type: enum value of
+                [InfluenceType][pydvl.influence.base_influence_model.InfluenceType]
+
+        Returns:
+            Tensor representing the element-wise scalar products for the
+            provided batches.
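+
+        Example:
+            A sketch of the block-wise usage, again assuming a fitted influence
+            function model `if_model` and torch data loaders `test_dataloader`
+            and `train_dataloader`; these names are illustrative:
+
+            ```python
+            from pydvl.influence.torch.util import TorchCatAggregator
+
+            calculator = SequentialInfluenceCalculator(
+                if_model, block_aggregator=TorchCatAggregator()
+            )
+            # for up-weighting influence, a tensor with one scalar product per
+            # (test sample, train sample) pair, assembled block by block
+            values = calculator.influences(test_dataloader, train_dataloader)
+            ```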
+        """
+        nested_tensor_gen = self._influences_gen(
+            test_data_iterable, train_data_iterable, influence_type
+        )
+        t: TensorType = self.block_aggregator.aggregate_nested(nested_tensor_gen)
+        return t
+
+    def _influences_from_factors_gen(
+        self,
+        z_test_factors: Iterable[TensorType],
+        train_data_iterable: Iterable[Tuple[TensorType, TensorType]],
+        influence_type: InfluenceType,
+    ) -> Generator[Generator[TensorType, None, None], None, None]:
+        for z_test_factor in iter(z_test_factors):
+            if isinstance(z_test_factor, (list, tuple)):
+                z_test_factor = z_test_factor[0]
+            yield (
+                self.influence_function_model.influences_from_factors(
+                    z_test_factor, x, y, influence_type
+                )
+                for x, y in iter(train_data_iterable)
+            )
+
+    def influences_from_factors(
+        self,
+        z_test_factors: Iterable[TensorType],
+        train_data_iterable: Iterable[Tuple[TensorType, TensorType]],
+        influence_type: InfluenceType = InfluenceType.Up,
+    ) -> TensorType:
+        r"""
+        Compute
+
+        \[ \langle z_{\text{test_factors}}, \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
+
+        for the case of up-weighting influence, or
+
+        \[ \langle z_{\text{test_factors}}, \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
+
+        for the case of perturbation type influence. The gradient is meant to
+        be per sample of the batch $(x, y)$.
+
+        Args:
+            z_test_factors: pre-computed iterable of tensors, approximating
+                $H^{-1}\nabla_{\theta} \ell(y_{\text{test}},
+                f_{\theta}(x_{\text{test}}))$
+            train_data_iterable: an iterable over batches of training data,
+                where each element is a tuple of tensors $(x, y)$
+            influence_type: enum value of
+                [InfluenceType][pydvl.influence.base_influence_model.InfluenceType]
+
+        Returns:
+            Tensor representing the element-wise scalar products for the
+            provided batches.
+        """
+        nested_tensor_gen = self._influences_from_factors_gen(
+            z_test_factors, train_data_iterable, influence_type
+        )
+        t: TensorType = self.block_aggregator.aggregate_nested(nested_tensor_gen)
+        return t
diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py
index 933c90297..d4f1d90cd 100644
--- a/src/pydvl/influence/torch/util.py
+++ b/src/pydvl/influence/torch/util.py
@@ -1,7 +1,17 @@
 import logging
 import math
 from functools import partial
-from typing import Collection, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+from typing import (
+    Collection,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import dask
 import numpy as np
@@ -11,7 +21,7 @@
 from torch.utils.data import Dataset, TensorDataset
 
 from pydvl.influence.base_influence_model import TensorType
-from pydvl.influence.influence_calculator import NumpyConverter
+from pydvl.influence.influence_calculator import BlockAggregator, NumpyConverter
 
 logger = logging.getLogger(__name__)
 
@@ -361,3 +371,36 @@ def from_numpy(self, x: NDArray) -> torch.Tensor:
         if self.device is not None:
             t = t.to(self.device)
         return t
+
+
+class TorchCatAggregator(BlockAggregator):
+    """
+    Collect tensors from (possibly nested) generators into a single tensor via
+    concatenation.
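+
+    Example:
+        A sketch, assuming `tensor_gen` is a generator of `torch.Tensor`
+        blocks, e.g. as produced internally by
+        [SequentialInfluenceCalculator][pydvl.influence.influence_calculator.SequentialInfluenceCalculator];
+        the name is illustrative:
+
+        ```python
+        aggregator = TorchCatAggregator()
+        # concatenates all blocks of the generator along dimension 0
+        full_tensor = aggregator.aggregate(tensor_gen)
+        ```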
+    """
+
+    def aggregate_nested(
+        self, tensors: Generator[Generator[torch.Tensor, None, None], None, None]
+    ) -> torch.Tensor:
+        """
+        Collect tensors from a nested generator structure into a single tensor.
+
+        Args:
+            tensors: generator providing generators of tensor blocks
+
+        Returns:
+            A single tensor, built from the blocks by concatenating the inner
+            generators along dimension 1 and the resulting tensors along
+            dimension 0
+        """
+        return torch.cat(
+            [torch.cat(list(tensor_gen), dim=1) for tensor_gen in tensors]
+        )
+
+    def aggregate(self, tensors: Generator[torch.Tensor, None, None]) -> torch.Tensor:
+        """
+        Collect tensors from a single-level generator into a single tensor.
+
+        Args:
+            tensors: generator providing tensor blocks
+
+        Returns:
+            A single tensor, built by concatenating the blocks along
+            dimension 0
+        """
+        return torch.cat(list(tensors))
diff --git a/tests/influence/test_influence_calculator.py b/tests/influence/test_influence_calculator.py
index a388de01b..ea01016ae 100644
--- a/tests/influence/test_influence_calculator.py
+++ b/tests/influence/test_influence_calculator.py
@@ -9,10 +9,11 @@
 from pydvl.influence.base_influence_model import UnSupportedInfluenceTypeException
 from pydvl.influence.influence_calculator import (
     DimensionChunksException,
+    SequentialInfluenceCalculator,
     UnalignedChunksException,
 )
 from pydvl.influence.torch import ArnoldiInfluence, BatchCgInfluence, DirectInfluence
-from pydvl.influence.torch.util import TorchNumpyConverter
+from pydvl.influence.torch.util import TorchCatAggregator, TorchNumpyConverter
 from tests.influence.torch.test_influence_model import model_and_data, test_case
 
@@ -222,3 +223,51 @@ def test_dask_influence_nn(model_and_data, test_case):
         dask_influence.influence_factors(
             da_x_test_unaligned_chunks, da_y_test_unaligned_chunks
         )
+
+
+def test_sequential_in_memory_calculator(model_and_data, test_case):
+    model, loss, x_train, y_train, x_test, y_test = model_and_data
+    train_dataloader = DataLoader(
+        TensorDataset(x_train, y_train), batch_size=test_case.batch_size
+    )
+    test_dataloader = DataLoader(
+        TensorDataset(x_test, y_test), batch_size=test_case.batch_size
+    )
+
+    inf_model = ArnoldiInfluence(
+        model,
+        test_case.loss,
+        hessian_regularization=test_case.hessian_reg,
+    ).fit(train_dataloader)
+
+    block_aggregator = TorchCatAggregator()
+    seq_calculator = SequentialInfluenceCalculator(inf_model, block_aggregator)
+
+    seq_factors = seq_calculator.influence_factors(test_dataloader)
+    torch_factors = inf_model.influence_factors(x_test, y_test)
+
+    assert torch.allclose(seq_factors, torch_factors, atol=1e-6)
+
+    torch_values_from_factors = inf_model.influences_from_factors(
+        torch_factors, x_train, y_train, influence_type=test_case.influence_type
+    )
+
+    seq_factors_data_loader = DataLoader(
+        TensorDataset(seq_factors), batch_size=test_case.batch_size
+    )
+
+    seq_values_from_factors = seq_calculator.influences_from_factors(
+        seq_factors_data_loader,
+        train_dataloader,
+        influence_type=test_case.influence_type,
+    )
+
+    assert torch.allclose(
+        seq_values_from_factors, torch_values_from_factors, atol=1e-6
+    )
+
+    seq_values = seq_calculator.influences(
+        test_dataloader, train_dataloader, influence_type=test_case.influence_type
+    )
+    torch_values = inf_model.influences(
+        x_test, y_test, x_train, y_train, influence_type=test_case.influence_type
+    )
+    assert torch.allclose(seq_values, torch_values, atol=1e-6)
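
A minimal end-to-end sketch of the two-stage workflow exercised by the test above; the `model`, `loss`, `x_train`, `y_train`, `x_test` and `y_test` names as well as the batch size and regularization value are illustrative placeholders:

```python
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.influence_calculator import SequentialInfluenceCalculator
from pydvl.influence.torch import ArnoldiInfluence
from pydvl.influence.torch.util import TorchCatAggregator

train_dataloader = DataLoader(TensorDataset(x_train, y_train), batch_size=32)
test_dataloader = DataLoader(TensorDataset(x_test, y_test), batch_size=32)

# fit the batch-wise computation model on the training data
if_model = ArnoldiInfluence(model, loss, hessian_regularization=0.1).fit(
    train_dataloader
)

calculator = SequentialInfluenceCalculator(if_model, TorchCatAggregator())

# stage 1: inverse Hessian vector products for the test data
factors = calculator.influence_factors(test_dataloader)

# stage 2: scalar products with the per-sample training gradients
factors_dataloader = DataLoader(TensorDataset(factors), batch_size=32)
values = calculator.influences_from_factors(factors_dataloader, train_dataloader)
```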