diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py
index 73a0f2933..9bc1dc225 100644
--- a/src/pydvl/influence/torch/base.py
+++ b/src/pydvl/influence/torch/base.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, TypeVar, Union
+from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast
 
 import torch
 from torch.func import functional_call
@@ -13,6 +13,7 @@
 from ..types import (
     Batch,
     BilinearForm,
+    BilinearFormType,
     BlockMapper,
     GradientProvider,
     Operator,
@@ -47,6 +48,9 @@ class TorchBatch(Batch):
     x: torch.Tensor
     y: torch.Tensor
 
+    def __iter__(self):
+        return iter((self.x, self.y))
+
     def __post_init__(self):
         if self.x.shape[0] != self.y.shape[0]:
             raise ValueError(
@@ -310,7 +314,7 @@ class OperatorBilinearForm(
 
     def __init__(
         self,
-        operator: "TorchOperator",
+        operator: "TensorOperator",
     ):
         self.operator = operator
 
@@ -345,7 +349,128 @@ def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tenso
         return torch.einsum("ia,j...a->ij...", left_result, right)
 
 
-class TorchOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC):
+class DictBilinearForm(OperatorBilinearForm):
+    r"""
+    Bilinear form based on an instance of
+    [TensorDictOperator][pydvl.influence.torch.base.TensorDictOperator]. It
+    computes weighted inner products of the form:
+
+    $$ \langle \operatorname{Op}(x), y \rangle $$
+
+    """
+
+    def __init__(
+        self,
+        operator: "TensorDictOperator",
+    ):
+        super().__init__(operator)
+
+    def grads_inner_prod(
+        self,
+        left: TorchBatch,
+        right: Optional[TorchBatch],
+        gradient_provider: TorchGradientProvider,
+    ) -> torch.Tensor:
+        r"""
+        Computes the gradient inner product of two batches of data, i.e.
+
+        $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}),
+        \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle_{B}$$
+
+        where $\nabla_{\omega}\ell(\omega, \cdot, \cdot)$ is represented by the
+        `gradient_provider` and the expression must be understood sample-wise.
+
+        Args:
+            left: The first batch for gradient and inner product computation
+            right: The second batch for gradient and inner product computation,
+                optional; if not provided, the inner product will use the gradient
+                computed for `left` for both arguments.
+            gradient_provider: The gradient provider to compute the gradients.
+
+        Returns:
+            A tensor representing the inner products of the per-sample gradients
+        """
+        operator = cast(TensorDictOperator, self.operator)
+        left_grads = gradient_provider.grads(left)
+        if right is None:
+            right_grads = left_grads
+        else:
+            right_grads = gradient_provider.grads(right)
+
+        left_batch_size, right_batch_size = next(
+            (
+                (l.shape[0], r.shape[0])
+                for l, r in zip(left_grads.values(), right_grads.values())
+            )
+        )
+
+        if left_batch_size <= right_batch_size:
+            left_grads = operator.apply_to_mat_dict(left_grads)
+            tensor_pairs = zip(left_grads.values(), right_grads.values())
+        else:
+            right_grads = operator.apply_to_mat_dict(right_grads)
+            tensor_pairs = zip(left_grads.values(), right_grads.values())
+
+        tensors_to_reduce = (
+            self._aggregate_grads(left, right) for left, right in tensor_pairs
+        )
+
+        return cast(torch.Tensor, sum(tensors_to_reduce))
+
+    def mixed_grads_inner_prod(
+        self,
+        left: TorchBatch,
+        right: TorchBatch,
+        gradient_provider: TorchGradientProvider,
+    ) -> torch.Tensor:
+        r"""
+        Computes the mixed gradient inner product of two batches of data, i.e.
+
+        $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}),
+        \nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y})
+        \rangle_{B}$$
+
+        where $\nabla_{\omega}\ell(\omega, \cdot)$ and
+        $\nabla_{\omega}\nabla_{x}\ell(\omega, \cdot)$ are represented by the
+        `gradient_provider`. The expression must be understood sample-wise.
+
+        Args:
+            left: The first batch for gradient and inner product computation
+            right: The second batch for gradient and inner product computation
+            gradient_provider: The gradient provider to compute the gradients.
+
+        Returns:
+            A tensor representing the inner products of the mixed per-sample gradients
+        """
+        operator = cast(TensorDictOperator, self.operator)
+        right_grads = gradient_provider.mixed_grads(right)
+        left_grads = gradient_provider.grads(left)
+        left_grads = operator.apply_to_mat_dict(left_grads)
+        left_grads_views = (t.reshape(t.shape[0], -1) for t in left_grads.values())
+        right_grads_views = (
+            t.reshape(*right.x.shape, -1) for t in right_grads.values()
+        )
+        tensor_pairs = zip(left_grads_views, right_grads_views)
+        tensors_to_reduce = (
+            self._aggregate_mixed_grads(left, right) for left, right in tensor_pairs
+        )
+        return cast(torch.Tensor, sum(tensors_to_reduce))
+
+    @staticmethod
+    def _aggregate_mixed_grads(left: torch.Tensor, right: torch.Tensor):
+        return torch.einsum("ik, j...k -> ij...", left, right)
+
+    @staticmethod
+    def _aggregate_grads(left: torch.Tensor, right: torch.Tensor):
+        return torch.einsum("i..., j... -> ij", left, right)
+
+
+OperatorBilinearFormType = TypeVar(
+    "OperatorBilinearFormType", bound=OperatorBilinearForm
+)
+
+
+class TensorOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC):
     """
     Abstract base class for operators that can be applied to instances of
     [torch.Tensor][torch.Tensor].
@@ -369,13 +494,6 @@ def to(self, device: torch.device):
     def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor:
         pass
 
-    def as_bilinear_form(self):
-        """
-        Represent this operator as a
-        [OperatorBilinearForm][pydvl.influence.torch.base.OperatorBilinearForm].
-        """
-        return OperatorBilinearForm(self)
-
     def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor:
         """
         Applies the operator to a single vector.
@@ -403,8 +521,72 @@ def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor:
         """
         return torch.func.vmap(self.apply_to_vec, in_dims=0, randomness="same")(mat)
 
+    def as_bilinear_form(self) -> OperatorBilinearForm:
+        return OperatorBilinearForm(self)
+
+
+class TensorDictOperator(TensorOperator, ABC):
+    """
+    Abstract base class for operators that can be applied to instances of
+    [torch.Tensor][torch.Tensor] and compatible dictionaries mapping strings to
+    tensors. Input dictionaries must conform to the structure defined by the
+    property `input_dict_structure`. Useful for operators involving autograd
+    functionality, to avoid intermediate flattening and concatenating of
+    gradient inputs.
+    """
+
+    def apply_to_mat_dict(
+        self, mat: Dict[str, torch.Tensor]
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Applies the operator to a dictionary of tensors, compatible with the
+        structure defined by the property `input_dict_structure`.
+
+        Args:
+            mat: dictionary of tensors, whose keys and shapes match the property
+                `input_dict_structure`.
+
+        Returns:
+            A dictionary of tensors after applying the operator
+        """
+
+        if not self._validate_mat_dict(mat):
+            raise ValueError(
+                f"Incompatible input structure, expected (excluding batch "
+                f"dimension): \n {self.input_dict_structure}"
+            )
+
+        return self._apply_to_mat_dict(self._dict_to_device(mat))
+
+    def _dict_to_device(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        return {k: v.to(self.device) for k, v in mat.items()}
+
+    @property
+    @abstractmethod
+    def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]:
+        """
+        Implement this to expose the expected structure of the input tensor dict,
+        i.e. a dictionary of shapes (excluding the first batch dimension), in
+        order to validate the input tensor dicts.
+        """
+
+    @abstractmethod
+    def _apply_to_mat_dict(
+        self, mat: Dict[str, torch.Tensor]
+    ) -> Dict[str, torch.Tensor]:
+        pass
+
+    def _validate_mat_dict(self, mat: Dict[str, torch.Tensor]) -> bool:
+        return all(
+            val.shape[1:] == self.input_dict_structure.get(key)
+            for key, val in mat.items()
+        )
+
+    def as_bilinear_form(self) -> DictBilinearForm:
+        return DictBilinearForm(self)
+
 
-TorchOperatorType = TypeVar("TorchOperatorType", bound=TorchOperator)
+TorchOperatorType = TypeVar("TorchOperatorType", bound=TensorOperator)
 
 
 class TorchOperatorGradientComposition(
diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py
index 4c7075924..30235899c 100644
--- a/src/pydvl/influence/torch/batch_operation.py
+++ b/src/pydvl/influence/torch/batch_operation.py
@@ -13,8 +13,21 @@
 which is useful in the case that keeping $B$ in memory is not feasible.
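+
+Example (a minimal sketch, not established API usage; it assumes a
+`model: torch.nn.Module`, a compatible `loss`, batch tensors `x`, `y` and a
+flat parameter-sized vector `vec` defined by the caller):
+
+```python
+from pydvl.influence.torch.base import TorchBatch
+from pydvl.influence.torch.batch_operation import HessianBatchOperation
+
+hvp = HessianBatchOperation(model, loss)
+# Hessian-vector product restricted to the data of a single batch
+result = hvp.apply_to_vec(TorchBatch(x, y), vec)
+```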
""" +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Callable, Dict, Optional, Type, TypeVar, Union +from typing import ( + Callable, + Dict, + Generator, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import torch @@ -24,37 +37,99 @@ TorchBatch, TorchGradientProvider, ) -from .functional import create_batch_hvp_function -from .util import LossType, inverse_rank_one_update, rank_one_mvp +from .functional import create_batch_hvp_function, create_batch_loss_function, hvp +from .util import ( + LossType, + generate_inverse_rank_one_updates, + generate_rank_one_mvp, + inverse_rank_one_update, + rank_one_mvp, +) -class BatchOperation(ABC): +class _ModelBasedBatchOperation(ABC): r""" Abstract base class to implement operations of the form - $$ m(b) \cdot v $$ + $$ m(\text{model}, b) \cdot v $$ + + where model is a [torch.nn.Module][torch.nn.Module]. - where $m(b)$ is a matrix defined by the data in the batch and $v$ is a vector - or matrix. """ + def __init__( + self, + model: torch.nn.Module, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + if restrict_to is None: + restrict_to = { + k: p.detach() for k, p in model.named_parameters() if p.requires_grad + } + self.params_to_restrict_to = restrict_to + self.model = model + @property - @abstractmethod - def input_size(self): - pass + def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: + return {k: p.shape for k, p in self.params_to_restrict_to.items()} @property - @abstractmethod def device(self): - pass + return next(self.model.parameters()).device @property - @abstractmethod def dtype(self): - pass + return next(self.model.parameters()).dtype + + @property + def input_size(self): + return sum(p.numel() for p in self.params_to_restrict_to.values()) - @abstractmethod def to(self, device: torch.device): + self.model = self.model.to(device) + self.params_to_restrict_to = { + k: p.detach() + for k, p in self.model.named_parameters() + if k in self.params_to_restrict_to + } + return self + + def apply_to_tensor_dict( + self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + + if mat_dict.keys() != self.params_to_restrict_to.keys(): + raise ValueError( + "The keys of the matrix dictionary must match the keys of the " + "parameters to restrict to." 
+ ) + + return self._apply_to_tensor_dict( + batch, {k: v.to(self.device) for k, v in mat_dict.items()} + ) + + def _has_batch_dim(self, tensor_dict: Dict[str, torch.Tensor]): + batch_dim_flags = [ + tensor_dict[key].shape == val.shape + for key, val in self.params_to_restrict_to.items() + ] + if len(set(batch_dim_flags)) == 2: + raise ValueError("Existence of batch dim must be consistent") + return not all(batch_dim_flags) + + def _add_batch_dim(self, vec_dict: Dict[str, torch.Tensor]): + result = {} + for key, value in self.params_to_restrict_to.items(): + if value.shape == vec_dict[key].shape: + result[key] = vec_dict[key].unsqueeze(0) + else: + result[key] = vec_dict[key] + return result + + @abstractmethod + def _apply_to_tensor_dict( + self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: pass @abstractmethod @@ -95,51 +170,7 @@ def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: )(batch.x, batch.y, mat) -class ModelBasedBatchOperation(BatchOperation, ABC): - r""" - Abstract base class to implement operations of the form - - $$ m(\text{model}, b) \cdot v $$ - - where model is a [torch.nn.Module][torch.nn.Module]. - - """ - - def __init__( - self, - model: torch.nn.Module, - restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, - ): - if restrict_to is None: - restrict_to = { - k: p.detach() for k, p in model.named_parameters() if p.requires_grad - } - self.params_to_restrict_to = restrict_to - self.model = model - - @property - def device(self): - return next(self.model.parameters()).device - - @property - def dtype(self): - return next(self.model.parameters()).dtype - - @property - def input_size(self): - return sum(p.numel() for p in self.params_to_restrict_to.values()) - - def to(self, device: torch.device): - self.model = self.model.to(device) - self.params_to_restrict_to = { - k: p.detach() - for k, p in self.model.named_parameters() - if k in self.params_to_restrict_to - } - return self - - -class HessianBatchOperation(ModelBasedBatchOperation): +class HessianBatchOperation(_ModelBasedBatchOperation): r""" Given a model and loss function computes the Hessian vector or matrix product with respect to the model parameters, i.e. 
@@ -173,12 +204,39 @@ def __init__(
         self._batch_hvp = create_batch_hvp_function(
             model, loss, reverse_only=reverse_only
         )
+        self.loss = loss
+        self.reverse_only = reverse_only
 
     def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor:
         return self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec)
 
+    def _apply_to_tensor_dict(
+        self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor]
+    ) -> Dict[str, torch.Tensor]:
+
+        func = self._create_seq_func(*batch)
+
+        if self._has_batch_dim(mat_dict):
+            func = torch.func.vmap(
+                func, in_dims=tuple((0 for _ in self.params_to_restrict_to))
+            )
+
+        result: Dict[str, torch.Tensor] = func(*mat_dict.values())
+        return result
+
+    def _create_seq_func(self, x: torch.Tensor, y: torch.Tensor):
+        def seq_func(*vec: torch.Tensor) -> Dict[str, torch.Tensor]:
+            return hvp(
+                lambda p: create_batch_loss_function(self.model, self.loss)(p, x, y),
+                self.params_to_restrict_to,
+                dict(zip(self.params_to_restrict_to.keys(), vec)),
+                reverse_only=self.reverse_only,
+            )
+
+        return seq_func
+
 
-class GaussNewtonBatchOperation(ModelBasedBatchOperation):
+class GaussNewtonBatchOperation(_ModelBasedBatchOperation):
     r"""
     Given a model and loss function computes the Gauss-Newton vector or matrix
     product with respect to the model parameters, i.e.
@@ -222,10 +280,18 @@ def __init__(
             model, loss, self.params_to_restrict_to
         )
 
+    def _apply_to_tensor_dict(
+        self, batch: TorchBatch, vec_dict: Dict[str, torch.Tensor]
+    ) -> Dict[str, torch.Tensor]:
+        vec_values = list(self._add_batch_dim(vec_dict).values())
+        grads_dict = self.gradient_provider.grads(batch)
+        grads_values = list(self._add_batch_dim(grads_dict).values())
+        gen_result = generate_rank_one_mvp(grads_values, vec_values)
+        return dict(zip(vec_dict.keys(), gen_result))
+
     def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor:
         flat_grads = self.gradient_provider.flat_grads(batch)
-        result = rank_one_mvp(flat_grads, vec)
-        return result
+        return rank_one_mvp(flat_grads, vec)
 
     def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor:
         """
@@ -248,7 +314,7 @@ def to(self, device: torch.device):
         return super().to(device)
 
 
-class InverseHarmonicMeanBatchOperation(ModelBasedBatchOperation):
+class InverseHarmonicMeanBatchOperation(_ModelBasedBatchOperation):
     r"""
     Given a model and loss function computes an approximation of the inverse
     Gauss-Newton vector or matrix product.
     Viewing the damped Gauss-newton matrix
@@ -327,10 +393,11 @@ def regularization(self, value: float):
 
     def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor:
         grads = self.gradient_provider.flat_grads(batch)
-        return (
-            inverse_rank_one_update(grads, vec, self.regularization)
-            / self.regularization
-        )
+        if vec.ndim == 1:
+            input_vec = vec.unsqueeze(0)
+        else:
+            input_vec = vec
+        return inverse_rank_one_update(grads, input_vec, self.regularization)
 
     def apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor:
         """
@@ -353,5 +420,111 @@ def to(self, device: torch.device):
         self.gradient_provider.params_to_restrict_to = self.params_to_restrict_to
         return self
 
+    def _apply_to_tensor_dict(
+        self, batch: TorchBatch, vec_dict: Dict[str, torch.Tensor]
+    ) -> Dict[str, torch.Tensor]:
+        vec_values = list(self._add_batch_dim(vec_dict).values())
+        grads_dict = self.gradient_provider.grads(batch)
+        grads_values = list(self._add_batch_dim(grads_dict).values())
+        gen_result = generate_inverse_rank_one_updates(
+            grads_values, vec_values, self.regularization
+        )
+        return dict(zip(vec_dict.keys(), gen_result))
+
+
+BatchOperationType = TypeVar("BatchOperationType", bound=_ModelBasedBatchOperation)
+
+
+class _TensorDictAveraging(ABC):
+    @abstractmethod
+    def __call__(self, tensor_dicts: Generator[Dict[str, torch.Tensor], None, None]):
+        pass
+
+
+_TensorDictAveragingType = TypeVar(
+    "_TensorDictAveragingType", bound=_TensorDictAveraging
+)
+
+
+class _TensorAveraging(Generic[_TensorDictAveragingType], ABC):
+    @abstractmethod
+    def __call__(self, tensors: Generator[torch.Tensor, None, None]):
+        pass
+
+    @abstractmethod
+    def as_dict_averaging(self) -> _TensorDictAveraging:
+        pass
+
+
+TensorAveragingType = TypeVar("TensorAveragingType", bound=_TensorAveraging)
+
+
+class _TensorDictChunkAveraging(_TensorDictAveraging):
+    def __call__(self, tensor_dicts: Generator[Dict[str, torch.Tensor], None, None]):
+        result = next(tensor_dicts)
+        n_chunks = 1.0
+        for tensor_dict in tensor_dicts:
+            for key, tensor in tensor_dict.items():
+                result[key] += tensor
+            n_chunks += 1.0
+        return {k: t / n_chunks for k, t in result.items()}
+
+
+class ChunkAveraging(_TensorAveraging[_TensorDictChunkAveraging]):
+    """
+    Averages tensors provided by a generator and normalizes by the number
+    of tensors.
+    """
+
+    def __call__(self, tensors: Generator[torch.Tensor, None, None]):
+        result = next(tensors)
+        n_chunks = 1.0
+        for tensor in tensors:
+            result += tensor
+            n_chunks += 1.0
+        return result / n_chunks
+
+    def as_dict_averaging(self) -> _TensorDictChunkAveraging:
+        return _TensorDictChunkAveraging()
+
+
+class _TensorDictPointAveraging(_TensorDictAveraging):
+    def __init__(self, batch_dim: int = 0):
+        self.batch_dim = batch_dim
+
+    def __call__(self, tensor_dicts: Generator[Dict[str, torch.Tensor], None, None]):
+        result = next(tensor_dicts)
+        n_points = next(iter(result.values())).shape[self.batch_dim]
+        for tensor_dict in tensor_dicts:
+            n_points_in_batch = next(iter(tensor_dict.values())).shape[self.batch_dim]
+            for key, tensor in tensor_dict.items():
+                result[key] += n_points_in_batch * tensor
+            n_points += n_points_in_batch
+        return {k: t / float(n_points) for k, t in result.items()}
+
+
+class PointAveraging(_TensorAveraging[_TensorDictPointAveraging]):
+    """
+    Averages tensors provided by a generator. The averaging is weighted by
+    the number of points in each tensor and the final result is normalized by
+    the total number of points.
+
+    Args:
+        batch_dim: Dimension to extract the number of points for the weighting.
+
+    """
+
+    def __init__(self, batch_dim: int = 0):
+        self.batch_dim = batch_dim
+
+    def __call__(self, tensors: Generator[torch.Tensor, None, None]):
+        result = next(tensors)
+        n_points = result.shape[self.batch_dim]
+        for tensor in tensors:
+            n_points_in_batch = tensor.shape[self.batch_dim]
+            result += n_points_in_batch * tensor
+            n_points += n_points_in_batch
+        return result / float(n_points)
 
-BatchOperationType = TypeVar("BatchOperationType", bound=BatchOperation)
+    def as_dict_averaging(self) -> _TensorDictPointAveraging:
+        return _TensorDictPointAveraging(self.batch_dim)
diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py
index 81d6b0442..745210f05 100644
--- a/src/pydvl/influence/torch/operator.py
+++ b/src/pydvl/influence/torch/operator.py
@@ -1,27 +1,30 @@
-from typing import Callable, Dict, Generator, Generic, Optional, Type, Union
+from typing import Callable, Dict, Generic, Optional, Tuple, Type, Union
 
 import torch
 from torch import nn as nn
 from torch.utils.data import DataLoader
 
-from ..array import LazyChunkSequence, SequenceAggregator
 from .base import (
     GradientProviderFactoryType,
+    TensorDictOperator,
     TorchAutoGrad,
     TorchBatch,
    TorchGradientProvider,
-    TorchOperator,
 )
 from .batch_operation import (
     BatchOperationType,
+    ChunkAveraging,
     GaussNewtonBatchOperation,
     HessianBatchOperation,
     InverseHarmonicMeanBatchOperation,
+    PointAveraging,
+    TensorAveragingType,
 )
-from .util import TorchChunkAverageAggregator, TorchPointAverageAggregator
 
 
-class AggregateBatchOperator(TorchOperator, Generic[BatchOperationType]):
+class _AveragingBatchOperator(
+    TensorDictOperator, Generic[BatchOperationType, TensorAveragingType]
+):
     """
     Class for aggregating batch operations over a dataset using a provided data loader
     and aggregator.
@@ -32,7 +35,7 @@ class AggregateBatchOperator(TorchOperator, Generic[BatchOperationType]):
     Attributes:
         batch_operation: The batch operation to apply.
        dataloader: The data loader providing batches of data.
-        aggregator: The sequence aggregator to aggregate the results of the batch
+        averaging: The averaging object used to aggregate the results of the batch
            operations.
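+
+    Example (sketch only; `model`, `loss` and `dataloader` are assumed to be
+    defined by the caller, and `HessianOperator` is one of the concrete
+    subclasses defined below):
+
+    ```python
+    op = HessianOperator(model, loss, dataloader)
+    hv = op.apply_to_vec(v)  # dataset-averaged Hessian-vector product
+    ```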
""" @@ -40,11 +43,27 @@ def __init__( self, batch_operation: BatchOperationType, dataloader: DataLoader, - aggregator: SequenceAggregator[torch.Tensor], + averager: TensorAveragingType, ): self.batch_operation = batch_operation self.dataloader = dataloader - self.aggregator = aggregator + self.averaging = averager + + @property + def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: + return self.batch_operation.input_dict_structure + + def _apply_to_mat_dict( + self, mat: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + + tensor_dicts = ( + self.batch_operation.apply_to_tensor_dict(TorchBatch(x, y), mat) + for x, y in self.dataloader + ) + dict_averaging = self.averaging.as_dict_averaging() + result: Dict[str, torch.Tensor] = dict_averaging(tensor_dicts) + return result @property def device(self): @@ -85,21 +104,20 @@ def _apply( z: torch.Tensor, batch_ops: Callable[[TorchBatch, torch.Tensor], torch.Tensor], ): - def tensor_gen_factory() -> Generator[torch.Tensor, None, None]: - return ( - batch_ops( - TorchBatch(x.to(self.device), y.to(self.device)), z.to(self.device) - ) - for x, y in self.dataloader - ) - lazy_tensor_sequence = LazyChunkSequence( - tensor_gen_factory, len_generator=len(self.dataloader) + tensors = ( + batch_ops( + TorchBatch(x.to(self.device), y.to(self.device)), z.to(self.device) + ) + for x, y in self.dataloader ) - return self.aggregator(lazy_tensor_sequence) + return self.averaging(tensors) -class GaussNewtonOperator(AggregateBatchOperator[GaussNewtonBatchOperation]): + +class GaussNewtonOperator( + _AveragingBatchOperator[GaussNewtonBatchOperation, PointAveraging] +): r""" Given a model and loss function computes the Gauss-Newton vector or matrix product with respect to the model parameters on a batch, i.e. @@ -142,11 +160,11 @@ def __init__( gradient_provider_factory=gradient_provider_factory, restrict_to=restrict_to, ) - aggregator = TorchPointAverageAggregator() - super().__init__(batch_op, dataloader, aggregator) + averaging = PointAveraging() + super().__init__(batch_op, dataloader, averaging) -class HessianOperator(AggregateBatchOperator[HessianBatchOperation]): +class HessianOperator(_AveragingBatchOperator[HessianBatchOperation, ChunkAveraging]): r""" Given a model and loss function computes the Hessian vector or matrix product with respect to the model parameters for a given batch, i.e. 
@@ -182,12 +200,12 @@ def __init__(
         batch_op = HessianBatchOperation(
             model, loss, restrict_to=restrict_to, reverse_only=reverse_only
         )
-        aggregator = TorchChunkAverageAggregator()
-        super().__init__(batch_op, dataloader, aggregator)
+        averaging = ChunkAveraging()
+        super().__init__(batch_op, dataloader, averaging)
 
 
 class InverseHarmonicMeanOperator(
-    AggregateBatchOperator[InverseHarmonicMeanBatchOperation]
+    _AveragingBatchOperator[InverseHarmonicMeanBatchOperation, PointAveraging]
 ):
     r"""
     Given a model and loss function computes an approximation of the inverse
@@ -265,8 +283,8 @@ def __init__(
             gradient_provider_factory=gradient_provider_factory,
             restrict_to=restrict_to,
         )
-        aggregator = TorchPointAverageAggregator(weighted=False)
-        super().__init__(batch_op, dataloader, aggregator)
+        averaging = PointAveraging()
+        super().__init__(batch_op, dataloader, averaging)
 
     @property
     def regularization(self):
diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py
index a3553af5c..661ace7da 100644
--- a/src/pydvl/influence/torch/util.py
+++ b/src/pydvl/influence/torch/util.py
@@ -11,6 +11,7 @@
     Callable,
     Collection,
     Dict,
+    Generator,
     Iterable,
     Iterator,
     List,
@@ -48,15 +49,11 @@
     "align_with_model",
     "flatten_dimensions",
     "TorchNumpyConverter",
-    "TorchCatAggregator",
-    "NestedTorchCatAggregator",
     "torch_dataset_to_dask_array",
     "EkfacRepresentation",
     "empirical_cross_entropy_loss_fn",
     "rank_one_mvp",
     "inverse_rank_one_update",
-    "TorchPointAverageAggregator",
-    "TorchChunkAverageAggregator",
     "LossType",
     "ModelParameterDictBuilder",
     "BlockMode",
@@ -451,33 +448,6 @@ def __call__(
         return torch.cat(list(t_gen))
 
 
-class TorchChunkAverageAggregator(SequenceAggregator[torch.Tensor]):
-    def __call__(self, tensor_sequence: LazyChunkSequence):
-        t_gen = tensor_sequence.generator_factory()
-        result = next(t_gen)
-        n_chunks = 1
-        for t in t_gen:
-            result += t
-            n_chunks += 1
-        return result / n_chunks
-
-
-class TorchPointAverageAggregator(SequenceAggregator[torch.Tensor]):
-    def __init__(self, batch_dim: int = 0, weighted: bool = True):
-        self.weighted = weighted
-        self.batch_dim = batch_dim
-
-    def __call__(self, tensor_sequence: LazyChunkSequence):
-        tensor_generator = tensor_sequence.generator_factory()
-        result = next(tensor_generator)
-        n_points = result.shape[self.batch_dim]
-        for tensor in tensor_generator:
-            n_points_in_batch = tensor.shape[self.batch_dim]
-            result += n_points_in_batch * tensor if self.weighted else tensor
-            n_points += n_points_in_batch
-        return result / n_points
-
-
 class NestedTorchCatAggregator(NestedSequenceAggregator[torch.Tensor]):
     """
     An aggregator that concatenates tensors using PyTorch's [torch.cat][torch.cat]
@@ -648,7 +618,7 @@ def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
     forming xx^T and sums the result. Here, X and V are matrices where each row
     represents an individual vector. Effectively it is computing
 
-    $$ V@(\sum_i^N x[i]x[i]^T) $$
+    $$ V@( \frac{1}{N}\sum_i^N x[i]x[i]^T) $$
 
     Args:
         x: Matrix of vectors of size `(N, M)`.
@@ -661,8 +631,23 @@ def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
     """
     if v.ndim == 1:
         result = torch.einsum("ij,kj->ki", x, v.unsqueeze(0)) @ x
-        return result.squeeze()
-    return torch.einsum("ij,kj->ki", x, v) @ x
+        return result.squeeze() / x.shape[0]
+    return (torch.einsum("ij,kj->ki", x, v) @ x) / x.shape[0]
+
+
+def generate_rank_one_mvp(
+    x: List[torch.Tensor], v: List[torch.Tensor]
+) -> Generator[torch.Tensor, None, None]:
+    x_v_iterator = zip(x, v)
+    x_, v_ = next(x_v_iterator)
+
+    nominator = torch.einsum("i..., k...->ki", x_, v_)
+
+    for x_, v_ in x_v_iterator:
+        nominator += torch.einsum("i..., k...->ki", x_, v_)
+
+    for x_, v_ in zip(x, v):
+        yield torch.einsum("ji, i... -> j...", nominator, x_) / x_.shape[0]
 
 
 def inverse_rank_one_update(
@@ -696,30 +681,64 @@ def inverse_rank_one_update(
     return (v - (nominator / denominator) @ x) / regularization
 
 
-def inverse_rank_one_update_dict(
-    x: Dict[str, torch.Tensor], v: Dict[str, torch.Tensor], regularization: float
-) -> Dict[str, torch.Tensor]:
+def generate_inverse_rank_one_updates(
+    x: List[torch.Tensor], v: List[torch.Tensor], regularization: float
+) -> Generator[torch.Tensor, None, None]:
+    def _check_batch_dim(t_x, t_v, idx: int):
+        if t_x.ndim <= 1:
+            raise ValueError(
+                f"Provided tensors in the lists must have at least "
+                f"2 dimensions, "
+                f"but found {t_x.ndim=} at {idx=} in list x"
+            )
+
+        if t_v.ndim <= 1:
+            raise ValueError(
+                f"Provided tensors in the lists must have at least "
+                f"2 dimensions, "
+                f"but found {t_v.ndim=} at {idx=} in list v"
+            )
+
+    def _create_dim_error(x_shape, v_shape, idx: int):
+        return ValueError(
+            f"Entries in the tensor lists must have the same shapes "
+            f"(excluding the first batch dimension), "
+            f"but found shapes {x_shape} and {v_shape} "
+            f"at {idx=}"
+        )
+
+    if not len(x) == len(v):
+        raise ValueError(
+            f"Provided tensor lists must have the same length, but got "
+            f"{len(x)=} and {len(v)=}"
+        )
 
-    denominator = regularization
-    nominator = None
-    batch_size = None
-    for x_, v_ in zip(x.values(), v.values()):
-        if batch_size is None:
-            batch_size = x_.shape[0]
-        if nominator is None:
-            nominator = torch.einsum("i..., k...->ki", x_, v_)
-        else:
-            nominator += torch.einsum("i..., k...->ki", x_, v_)
+    x_v_iterator = enumerate(zip(x, v))
+    index, (x_, v_) = next(x_v_iterator)
+
+    _check_batch_dim(x_, v_, index)
+
+    if x_.shape[1:] != v_.shape[1:]:
+        raise _create_dim_error(x_.shape[1:], v_.shape[1:], index)
+
+    denominator = regularization + torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1)
+    nominator = torch.einsum("i..., k...->ki", x_, v_)
+    num_data_points = x_.shape[0]
+
+    for k, (x_, v_) in x_v_iterator:
+        _check_batch_dim(x_, v_, k)
+        if x_.shape[1:] != v_.shape[1:]:
+            raise _create_dim_error(x_.shape[1:], v_.shape[1:], k)
+
+        nominator += torch.einsum("i..., k...->ki", x_, v_)
         denominator += torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1)
 
-    denominator = batch_size * denominator
-    result = {}
-    for key in x.keys():
-        result[key] = (
-            v[key] - torch.einsum("ji, i... -> j...", nominator / denominator, x[key])
-        ) / regularization
+    denominator = num_data_points * denominator
 
-    return result
+    for x_, v_ in zip(x, v):
+        yield (
+            v_ - torch.einsum("ji, i... -> j...", nominator / denominator, x_)
+        ) / regularization
 
 
 LossType = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
diff --git a/tests/influence/test_influence_calculator.py b/tests/influence/test_influence_calculator.py
index 9a82e89cf..70a29bf1a 100644
--- a/tests/influence/test_influence_calculator.py
+++ b/tests/influence/test_influence_calculator.py
@@ -1,5 +1,3 @@
-import uuid
-
 import dask.array as da
 import numpy as np
 import pytest
diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py
index 304de4518..a987857e6 100644
--- a/tests/influence/torch/test_batch_operation.py
+++ b/tests/influence/torch/test_batch_operation.py
@@ -7,7 +7,9 @@
 from pydvl.influence.torch.batch_operation import (
     GaussNewtonBatchOperation,
     HessianBatchOperation,
+    InverseHarmonicMeanBatchOperation,
 )
+from pydvl.influence.torch.util import align_structure, flatten_dimensions
 
 from .test_util import model_data, test_parameters
 
@@ -18,7 +20,7 @@
     [(astuple(tp.model_params), 1e-5) for tp in test_parameters],
     indirect=["model_data"],
 )
-def test_hessian_batch_operation(model_data, tol: float):
+def test_hessian_batch_operation(model_data, tol: float, pytorch_seed):
     torch_model, x, y, vec, h_analytical = model_data
 
     params = dict(torch_model.named_parameters())
@@ -26,9 +28,36 @@ def test_hessian_batch_operation(model_data, tol: float):
     hessian_op = HessianBatchOperation(
         torch_model, torch.nn.functional.mse_loss, restrict_to=params
     )
+    batch_size = 10
+    rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()}
+    flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1))
+    hvp_autograd_mat_dict = hessian_op.apply_to_tensor_dict(
+        TorchBatch(x, y), rand_mat_dict
+    )
+
     hvp_autograd = hessian_op.apply_to_vec(TorchBatch(x, y), vec)
+    hvp_autograd_dict = hessian_op.apply_to_tensor_dict(
+        TorchBatch(x, y), align_structure(params, vec)
+    )
+    hvp_autograd_dict_flat = flatten_dimensions(hvp_autograd_dict.values())
 
     assert torch.allclose(hvp_autograd, h_analytical @ vec, rtol=tol)
+    assert torch.allclose(hvp_autograd_dict_flat, h_analytical @ vec, rtol=tol)
+
+    op_then_flat = flatten_dimensions(
+        hvp_autograd_mat_dict.values(), shape=(batch_size, -1)
+    )
+    flat_then_op_analytical = torch.einsum("ik, jk -> ji", h_analytical, flat_rand_mat)
+
+    assert torch.allclose(
+        op_then_flat,
+        flat_then_op_analytical,
+        atol=1e-5,
+        rtol=tol,
+    )
+    assert torch.allclose(
+        hessian_op.apply_to_mat(TorchBatch(x, y), flat_rand_mat), op_then_flat
+    )
 
 
 @pytest.mark.torch
@@ -37,7 +66,7 @@
     [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
     indirect=["model_data"],
 )
-def test_gauss_newton_batch_operation(model_data, tol: float):
+def test_gauss_newton_batch_operation(model_data, tol: float, pytorch_seed):
     torch_model, x, y, vec, _ = model_data
 
     y_pred = torch_model(x)
@@ -47,11 +76,14 @@ def test_gauss_newton_batch_operation(model_data, tol: float):
     )(x, y_pred, y)
     dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y)
     grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1)
-    gn_mat_analytical = torch.sum(
-        torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())(
-            grad_analytical
-        ),
-        dim=0,
+    gn_mat_analytical = (
+        torch.sum(
+            torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())(
+                grad_analytical
+            ),
+            dim=0,
+        )
+        / x.shape[0]
     )
 
     params = dict(torch_model.named_parameters())
@@ -59,8 +91,115 @@ def test_gauss_newton_batch_operation(model_data, tol: float):
     gn_op = GaussNewtonBatchOperation(
         torch_model, torch.nn.functional.mse_loss, restrict_to=params
     )
+    batch_size = 10
+
+    gn_autograd = gn_op.apply_to_vec(TorchBatch(x, y), vec)
+    gn_autograd_dict = gn_op.apply_to_tensor_dict(
+        TorchBatch(x, y), align_structure(params, vec)
+    )
+    gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
+    analytical_vec = gn_mat_analytical @ vec
+    assert torch.allclose(gn_autograd, analytical_vec, atol=1e-5, rtol=tol)
+    assert torch.allclose(gn_autograd_dict_flat, analytical_vec, atol=1e-5, rtol=tol)
+
+    rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()}
+    flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1))
+    gn_autograd_mat_dict = gn_op.apply_to_tensor_dict(TorchBatch(x, y), rand_mat_dict)
+
+    op_then_flat = flatten_dimensions(
+        gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+    )
+    flat_then_op = gn_op.apply_to_mat(TorchBatch(x, y), flat_rand_mat)
+
+    assert torch.allclose(
+        op_then_flat,
+        flat_then_op,
+        atol=1e-5,
+        rtol=tol,
+    )
+
+    flat_then_op_analytical = torch.einsum(
+        "ik, jk -> ji", gn_mat_analytical, flat_rand_mat
+    )
+
+    assert torch.allclose(
+        op_then_flat,
+        flat_then_op_analytical,
+        atol=1e-5,
+        rtol=1e-2,
+    )
+
+
+@pytest.mark.torch
+@pytest.mark.parametrize(
+    "model_data, tol",
+    [(astuple(tp.model_params), 1e-3) for tp in test_parameters],
+    indirect=["model_data"],
+)
+@pytest.mark.parametrize("reg", [0.4])
+def test_inverse_harmonic_mean_batch_operation(
+    model_data, tol: float, reg, pytorch_seed
+):
+    torch_model, x, y, vec, _ = model_data
+    y_pred = torch_model(x)
+    out_features = y_pred.shape[1]
+    dl_dw = torch.vmap(
+        lambda r, s, t: 2 / float(out_features) * (s - t).view(-1, 1) @ r.view(1, -1)
+    )(x, y_pred, y)
+    dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y)
+    grad_analytical = torch.cat([dl_dw.reshape(x.shape[0], -1), dl_db], dim=-1)
+    params = {
+        k: p.detach() for k, p in torch_model.named_parameters() if p.requires_grad
+    }
+
+    ihm_mat_analytical = torch.sum(
+        torch.func.vmap(
+            lambda z: torch.linalg.inv(
+                z.unsqueeze(-1) * z.unsqueeze(-1).t() + reg * torch.eye(len(z))
+            )
+        )(grad_analytical),
+        dim=0,
+    )
+    ihm_mat_analytical /= x.shape[0]
+
+    gn_op = InverseHarmonicMeanBatchOperation(
+        torch_model, torch.nn.functional.mse_loss, reg, restrict_to=params
+    )
+    batch_size = 10
+    gn_autograd = gn_op.apply_to_vec(TorchBatch(x, y), vec)
+    gn_autograd_dict = gn_op.apply_to_tensor_dict(
+        TorchBatch(x, y), align_structure(params, vec)
+    )
+    gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values())
+    analytical = ihm_mat_analytical @ vec
+
+    assert torch.allclose(gn_autograd, analytical, atol=1e-5, rtol=tol)
+    assert torch.allclose(gn_autograd_dict_flat, analytical, atol=1e-5, rtol=tol)
 
-    gn_analytical = gn_mat_analytical @ vec
+    rand_mat_dict = {k: torch.randn(batch_size, *t.shape) for k, t in params.items()}
+    flat_rand_mat = flatten_dimensions(rand_mat_dict.values(), shape=(batch_size, -1))
+    gn_autograd_mat_dict = gn_op.apply_to_tensor_dict(TorchBatch(x, y), rand_mat_dict)
 
-    assert torch.allclose(gn_autograd, gn_analytical, atol=1e-5, rtol=tol)
+    op_then_flat = flatten_dimensions(
+        gn_autograd_mat_dict.values(), shape=(batch_size, -1)
+    )
+    flat_then_op = gn_op.apply_to_mat(TorchBatch(x, y), flat_rand_mat)
+
+    assert torch.allclose(
+        op_then_flat,
+        flat_then_op,
+        atol=1e-5,
+        rtol=tol,
+    )
+
+    flat_then_op_analytical = torch.einsum(
+        "ik, jk -> ji", ihm_mat_analytical, flat_rand_mat
+    )
+
+    assert torch.allclose(
+        op_then_flat,
+        flat_then_op_analytical,
+        atol=1e-5,
+        rtol=1e-2,
+    )
diff --git a/tests/influence/torch/test_util.py b/tests/influence/torch/test_util.py
index fc80ad805..bd11b05e2 100644
--- a/tests/influence/torch/test_util.py
+++ b/tests/influence/torch/test_util.py
@@ -23,8 +23,9 @@
     TorchTensorContainerType,
     align_structure,
     flatten_dimensions,
+    generate_inverse_rank_one_updates,
+    generate_rank_one_mvp,
     inverse_rank_one_update,
-    inverse_rank_one_update_dict,
     rank_one_mvp,
     safe_torch_linalg_eigh,
     torch_dataset_to_dask_array,
@@ -338,7 +339,7 @@ def test_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0):
         (torch.vmap(lambda x: x.unsqueeze(-1) * x.unsqueeze(-1).t())(X) @ V.t())
         .sum(dim=0)
         .t()
-    )
+    ) / x_dim_0
 
     result = rank_one_mvp(X, V)
 
@@ -346,6 +347,35 @@ def test_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0):
     assert torch.allclose(result, expected, atol=1e-5, rtol=1e-4)
 
 
+@pytest.mark.torch
+@pytest.mark.parametrize(
+    "x_dim_1",
+    [
+        [(4, 2, 3), (5, 7), (5,)],
+        [(3, 6, 8, 9), (1, 2)],
+        [(1,)],
+    ],
+)
+@pytest.mark.parametrize(
+    "x_dim_0, v_dim_0",
+    [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)],
+)
+def test_generate_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0):
+    x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1]
+    v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1]
+
+    x = flatten_dimensions(x_list, shape=(x_dim_0, -1))
+    v = flatten_dimensions(v_list, shape=(v_dim_0, -1))
+    result = rank_one_mvp(x, v)
+
+    gen_result = flatten_dimensions(
+        generate_rank_one_mvp(x_list, v_list),
+        shape=(v_dim_0, -1),
+    )
+
+    assert torch.allclose(result, gen_result, atol=1e-5, rtol=1e-3)
+
+
 @pytest.mark.torch
 @pytest.mark.parametrize(
     "x_dim_0, x_dim_1, v_dim_0",
@@ -373,26 +403,31 @@ def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg):
 
 @pytest.mark.torch
 @pytest.mark.parametrize(
     "x_dim_1",
-    [{"1": (4, 2, 3), "2": (5, 7), "3": ()}, {"1": (3, 6, 8, 9), "2": (1, 2)}, {"1": (1,)}],
+    [
+        [(4, 2, 3), (5, 7), (5,)],
+        [(3, 6, 8, 9), (1, 2)],
+        [(1,)],
+    ],
 )
 @pytest.mark.parametrize(
     "x_dim_0, v_dim_0",
     [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)],
 )
 @pytest.mark.parametrize("reg", [0.5, 100, 1.0, 10])
-def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg):
-    X_dict = {k: torch.randn(x_dim_0, *d) for k, d in x_dim_1.items()}
-    V_dict = {k: torch.randn(v_dim_0, *d) for k, d in x_dim_1.items()}
+def test_generate_inverse_rank_one_updates(x_dim_0, x_dim_1, v_dim_0, reg):
+    x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1]
+    v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1]
 
-    X = flatten_dimensions(X_dict.values(), shape=(x_dim_0, -1))
-    V = flatten_dimensions(V_dict.values(), shape=(v_dim_0, -1))
-    result = inverse_rank_one_update(X, V, reg)
+    x = flatten_dimensions(x_list, shape=(x_dim_0, -1))
+    v = flatten_dimensions(v_list, shape=(v_dim_0, -1))
+    result = inverse_rank_one_update(x, v, reg)
 
     inverse_result = flatten_dimensions(
-        inverse_rank_one_update_dict(X_dict, V_dict, reg).values(), shape=(v_dim_0, -1)
+        generate_inverse_rank_one_updates(x_list, v_list, reg),
+        shape=(v_dim_0, -1),
     )
 
-    assert torch.allclose(result, inverse_result, atol=1e-5, rtol=1e-3)
+    assert torch.allclose(result, inverse_result)
 
 
 class TestModelParameterDictBuilder:
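Taken together, the new tests pin down the contract of the dict-based code path: applying an operation to a structured tensor dict and then flattening agrees with flattening first and applying the flat operation. A condensed sketch of that equivalence, reusing names from tests/influence/torch/test_batch_operation.py (`torch_model`, `x`, `y` and `vec` come from the `model_data` fixture):

    params = {k: p.detach() for k, p in torch_model.named_parameters() if p.requires_grad}
    op = HessianBatchOperation(torch_model, torch.nn.functional.mse_loss, restrict_to=params)
    flat = op.apply_to_vec(TorchBatch(x, y), vec)
    as_dict = op.apply_to_tensor_dict(TorchBatch(x, y), align_structure(params, vec))
    # flattening the dict-shaped result recovers the flat result
    assert torch.allclose(flatten_dimensions(as_dict.values()), flat, atol=1e-5)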