
Commit

Add sequential calculator
schroedk committed Dec 7, 2023
1 parent 0687456 commit 57a8196
Showing 3 changed files with 285 additions and 5 deletions.
192 changes: 190 additions & 2 deletions src/pydvl/influence/influence_calculator.py
@@ -1,9 +1,18 @@
from abc import ABC, abstractmethod
from math import prod
from typing import Any, Callable, Generic, Optional, Tuple
from typing import (
Callable,
Generator,
Generic,
Iterable,
List,
Optional,
Tuple,
TypeVar,
)

import distributed
import numpy as np
import zarr
from dask import array as da
from dask import delayed
from numpy.typing import NDArray
@@ -357,3 +366,182 @@ def _get_client() -> Optional[distributed.Client]:
return distributed.get_client()
except ValueError:
return None


class BlockAggregator(Generic[TensorType], ABC):
@abstractmethod
def aggregate_nested(
self, tensors: Generator[Generator[TensorType, None, None], None, None]
):
"""Overwrite this method to aggregate provided blocks into a single tensor"""

@abstractmethod
def aggregate(self, tensors: Generator[TensorType, None, None]):
"""Overwrite this method to aggregate provided list of tensors into a single tensor"""


class ListAggregator(BlockAggregator):
def aggregate_nested(
self, tensors: Generator[Generator[TensorType, None, None], None, None]
):
return [list(tensor_gen) for tensor_gen in tensors]

def aggregate(self, tensors: Generator[TensorType, None, None]):
return list(tensors)


class SequentialInfluenceCalculator:
"""
Simple wrapper class to process batches of data sequentially. It depends on a batch computation model
of type [InfluenceFunctionModel][pydvl.influence.base_influence_model.InfluenceFunctionModel].

Args:
influence_function_model: instance of type
[InfluenceFunctionModel][pydvl.influence.base_influence_model.InfluenceFunctionModel], which defines
the batch-wise computation model
block_aggregator: optional instance of type [BlockAggregator][pydvl.influence.influence_calculator.BlockAggregator],
used to collect and aggregate the tensors from the sequential process. If None, tensors are collected
into nested lists.
"""

def __init__(
self,
influence_function_model: InfluenceFunctionModel,
block_aggregator: Optional[BlockAggregator] = None,
):
self.block_aggregator = (
block_aggregator if block_aggregator is not None else ListAggregator()
)
self.influence_function_model = influence_function_model

def _influence_factors_gen(
self, data_iterable: Iterable[Tuple[TensorType, TensorType]]
) -> Generator[TensorType, None, None]:
for x, y in iter(data_iterable):
yield self.influence_function_model.influence_factors(x, y)

def influence_factors(
self,
data_iterable: Iterable[Tuple[TensorType, TensorType]],
) -> TensorType:
r"""
Compute the expression
\[ H^{-1}\nabla_{\theta} \ell(y, f_{\theta}(x)) \]
where the gradients are computed for the chunks $(x, y)$ of the data_iterable in a sequential manner and
aggregated into a single tensor.

Args:
data_iterable: An iterable yielding batches $(x, y)$ of model inputs and targets.

Returns:
Tensor representing the element-wise inverse Hessian matrix vector
products for the provided batches.
"""
tensors_gen = self._influence_factors_gen(data_iterable)
t: TensorType = self.block_aggregator.aggregate(tensors_gen)
return t

def _influences_gen(
self,
test_data_iterable: Iterable[Tuple[TensorType, TensorType]],
train_data_iterable: Iterable[Tuple[TensorType, TensorType]],
influence_type: InfluenceType,
) -> Generator[Generator[TensorType, None, None], None, None]:

for x_test, y_test in iter(test_data_iterable):
yield (
self.influence_function_model.influences(
x_test, y_test, x, y, influence_type
)
for x, y in iter(train_data_iterable)
)

def influences(
self,
test_data_iterable: Iterable[Tuple[TensorType, TensorType]],
train_data_iterable: Iterable[Tuple[TensorType, TensorType]],
influence_type: InfluenceType = InfluenceType.Up,
) -> TensorType:
r"""
Compute approximation of
\[ \langle H^{-1}\nabla_{\theta} \ell(y_{\text{test}},
f_{\theta}(x_{\text{test}})), \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
in the case of up-weighting influence, and
\[ \langle H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, f_{\theta}(x_{\text{test}})),
\nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
in the case of perturbation-type influence. The computation is done block-wise for the chunks of the provided
data iterables and aggregated into a single tensor in memory.

Args:
test_data_iterable: An iterable yielding test batches $(x_{\text{test}}, y_{\text{test}})$.
train_data_iterable: An iterable yielding train batches $(x, y)$.
influence_type: enum value of [InfluenceType][pydvl.influence.base_influence_model.InfluenceType]

Returns:
Tensor representing the element-wise scalar products for the provided batches.
"""
nested_tensor_gen = self._influences_gen(
test_data_iterable, train_data_iterable, influence_type
)

t: TensorType = self.block_aggregator.aggregate_nested(nested_tensor_gen)
return t

def _influences_from_factors_gen(
self,
z_test_factors: Iterable[TensorType],
train_data_iterable: Iterable[Tuple[TensorType, TensorType]],
influence_type: InfluenceType,
) -> Generator[Generator[TensorType, None, None], None, None]:

for z_test_factor in iter(z_test_factors):
if isinstance(z_test_factor, (list, tuple)):
z_test_factor = z_test_factor[0]
yield (
self.influence_function_model.influences_from_factors(
z_test_factor, x, y, influence_type
)
for x, y in iter(train_data_iterable)
)

def influences_from_factors(
self,
z_test_factors: Iterable[TensorType],
train_data_iterable: Iterable[Tuple[TensorType, TensorType]],
influence_type: InfluenceType = InfluenceType.Up,
) -> TensorType:
r"""
Compute
\[ \langle z_{\text{test_factors}}, \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
in the case of up-weighting influence, and
\[ \langle z_{\text{test_factors}}, \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
in the case of perturbation-type influence. The gradients are computed per sample of the batch $(x, y)$.

Args:
z_test_factors: pre-computed iterable of tensors, approximating
$H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, f_{\theta}(x_{\text{test}}))$
train_data_iterable: An iterable yielding train batches $(x, y)$.
influence_type: enum value of [InfluenceType][pydvl.influence.base_influence_model.InfluenceType]

Returns:
Tensor representing the element-wise scalar products for the provided batches.
"""
nested_tensor_gen = self._influences_from_factors_gen(
z_test_factors, train_data_iterable, influence_type
)
t: TensorType = self.block_aggregator.aggregate_nested(nested_tensor_gen)
return t
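
To make the new API concrete, here is a minimal usage sketch of SequentialInfluenceCalculator with its default ListAggregator. All model, loss, and data names below are illustrative placeholders, not part of this commit; the ArnoldiInfluence constructor and fit call mirror the ones used in the test further down.

import torch
from torch.nn.functional import mse_loss
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.influence_calculator import SequentialInfluenceCalculator
from pydvl.influence.torch import ArnoldiInfluence

# Hypothetical toy model and data, for illustration only.
model = torch.nn.Linear(10, 1)
x_train, y_train = torch.randn(32, 10), torch.randn(32, 1)
x_test, y_test = torch.randn(8, 10), torch.randn(8, 1)

train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=8)
test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=4)

# Fit the batch-wise computation model once, then process batches sequentially.
inf_model = ArnoldiInfluence(model, mse_loss, hessian_regularization=0.1)
inf_model = inf_model.fit(train_loader)

# With no aggregator given, the default ListAggregator collects results
# into a list (influence_factors) or a nested list (influences).
seq_calculator = SequentialInfluenceCalculator(inf_model)
factors = seq_calculator.influence_factors(test_loader)
values = seq_calculator.influences(test_loader, train_loader)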
47 changes: 45 additions & 2 deletions src/pydvl/influence/torch/util.py
@@ -1,7 +1,17 @@
import logging
import math
from functools import partial
from typing import Collection, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from typing import (
Collection,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)

import dask
import numpy as np
@@ -11,7 +21,7 @@
from torch.utils.data import Dataset, TensorDataset

from pydvl.influence.base_influence_model import TensorType
from pydvl.influence.influence_calculator import NumpyConverter
from pydvl.influence.influence_calculator import BlockAggregator, NumpyConverter

logger = logging.getLogger(__name__)

@@ -361,3 +371,36 @@ def from_numpy(self, x: NDArray) -> torch.Tensor:
if self.device is not None:
t = t.to(self.device)
return t


class TorchCatAggregator(BlockAggregator):
"""
Collect tensors from (nested) generators into a single tensor by concatenation
"""

def aggregate_nested(
self, tensors: Generator[Generator[torch.Tensor, None, None], None, None]
) -> torch.Tensor:
"""
Args:
tensors: nested generator providing tensor blocks

Returns:
A single tensor, built by concatenating the blocks of each inner
generator along dimension 1 and the resulting tensors along dimension 0
"""
return torch.cat(
[torch.cat(list(tensor_gen), dim=1) for tensor_gen in tensors]
)

def aggregate(self, tensors: Generator[torch.Tensor, None, None]) -> torch.Tensor:
"""
Collect tensors from a single-level generator into a single tensor

Args:
tensors: generator providing tensor blocks

Returns:
A single tensor, built by concatenating the blocks along dimension 0
"""
return torch.cat(list(tensors))
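
The concatenation semantics of TorchCatAggregator deserve a spelled-out example: aggregate_nested concatenates the blocks of each inner (train-axis) generator along dimension 1 and the resulting row tensors along dimension 0, turning a grid of pairwise influence blocks into one full matrix. A small self-contained sketch with illustrative shapes:

import torch

from pydvl.influence.torch.util import TorchCatAggregator

def nested_blocks():
    # 2 test batches times 3 train batches of pairwise blocks of shape (4, 5)
    for _ in range(2):
        yield (torch.ones(4, 5) for _ in range(3))

aggregator = TorchCatAggregator()
full = aggregator.aggregate_nested(nested_blocks())
assert full.shape == (8, 15)  # (2 * 4 test points, 3 * 5 train points)

# The single-level variant concatenates along dimension 0 only.
flat = aggregator.aggregate(torch.ones(4, 5) for _ in range(3))
assert flat.shape == (12, 5)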
51 changes: 50 additions & 1 deletion tests/influence/test_influence_calculator.py
@@ -9,10 +9,11 @@
from pydvl.influence.base_influence_model import UnSupportedInfluenceTypeException
from pydvl.influence.influence_calculator import (
DimensionChunksException,
SequentialInfluenceCalculator,
UnalignedChunksException,
)
from pydvl.influence.torch import ArnoldiInfluence, BatchCgInfluence, DirectInfluence
from pydvl.influence.torch.util import TorchNumpyConverter
from pydvl.influence.torch.util import TorchCatAggregator, TorchNumpyConverter
from tests.influence.torch.test_influence_model import model_and_data, test_case


@@ -222,3 +223,51 @@ def test_dask_influence_nn(model_and_data, test_case):
dask_influence.influence_factors(
da_x_test_unaligned_chunks, da_y_test_unaligned_chunks
)


def test_sequential_in_memory_calculator(model_and_data, test_case):
model, loss, x_train, y_train, x_test, y_test = model_and_data
train_dataloader = DataLoader(
TensorDataset(x_train, y_train), batch_size=test_case.batch_size
)
test_dataloader = DataLoader(
TensorDataset(x_test, y_test), batch_size=test_case.batch_size
)

inf_model = ArnoldiInfluence(
model,
test_case.loss,
hessian_regularization=test_case.hessian_reg,
).fit(train_dataloader)

block_aggregator = TorchCatAggregator()
seq_calculator = SequentialInfluenceCalculator(inf_model, block_aggregator)

seq_factors = seq_calculator.influence_factors(test_dataloader)
torch_factors = inf_model.influence_factors(x_test, y_test)

assert torch.allclose(seq_factors, torch_factors, atol=1e-6)

torch_values_from_factors = inf_model.influences_from_factors(
torch_factors, x_train, y_train, influence_type=test_case.influence_type
)

seq_factors_data_loader = DataLoader(
TensorDataset(seq_factors), batch_size=test_case.batch_size
)

seq_values_from_factors = seq_calculator.influences_from_factors(
seq_factors_data_loader,
train_dataloader,
influence_type=test_case.influence_type,
)

assert torch.allclose(seq_values_from_factors, torch_values_from_factors, atol=1e-6)

seq_values = seq_calculator.influences(
test_dataloader, train_dataloader, influence_type=test_case.influence_type
)
torch_values = inf_model.influences(
x_test, y_test, x_train, y_train, influence_type=test_case.influence_type
)
assert torch.allclose(seq_values, torch_values, atol=1e-6)

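The BlockAggregator interface added in this commit is also the extension point for other backends. As a sketch of the contract only: the NumpyCatAggregator below is hypothetical, not part of the commit, and assumes the yielded blocks are NumPy arrays.

import numpy as np

from pydvl.influence.influence_calculator import BlockAggregator

class NumpyCatAggregator(BlockAggregator):
    """Hypothetical aggregator collecting NumPy blocks into a single array."""

    def aggregate_nested(self, tensors):
        # Inner generators run over train batches (axis 1), the outer
        # generator over test batches (axis 0), as in TorchCatAggregator.
        return np.concatenate(
            [np.concatenate(list(gen), axis=1) for gen in tensors], axis=0
        )

    def aggregate(self, tensors):
        return np.concatenate(list(tensors), axis=0)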