Implement block-diagonal and Gauss-Newton approximation for ArnoldiInfluence,
remove obsolete functions lanczos_low_rank_hessian_approximation, model_hessian_low_rank,
rename parameter rank_estimate -> rank, hessian_regularization -> regularization
schroedk committed Jun 13, 2024
1 parent 5e27519 commit e68a21d
Showing 7 changed files with 308 additions and 346 deletions.
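For orientation before the diff, a minimal sketch of what the parameter rename means at the call site. The exact constructor signature is an assumption inferred from the commit message, not verified against the pydvl API; `model` and `loss` are placeholders:

```python
import torch

from pydvl.influence.torch import ArnoldiInfluence

model = torch.nn.Linear(10, 1)       # placeholder model
loss = torch.nn.functional.mse_loss  # placeholder loss

# Before this commit (old parameter names):
#   ArnoldiInfluence(model, loss, rank_estimate=10, hessian_regularization=1e-3)
# After this commit:
if_model = ArnoldiInfluence(model, loss, rank=10, regularization=1e-3)
```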
77 changes: 74 additions & 3 deletions src/pydvl/influence/torch/base.py
@@ -3,7 +3,18 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union, cast
from typing import (
    TYPE_CHECKING,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import torch
from torch.func import functional_call
@@ -12,11 +23,14 @@
from ..base_influence_function_model import ComposableInfluence
from ..types import (
    Batch,
    BatchType,
    BilinearForm,
    BlockMapper,
    GradientProvider,
    GradientProviderType,
    Operator,
    OperatorGradientComposition,
    TensorType,
)
from .util import (
    BlockMode,
@@ -27,6 +41,9 @@
    flatten_dimensions,
)

if TYPE_CHECKING:
    from .operator import LowRankOperator
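A brief note on the `TYPE_CHECKING` guard added above: the import is only evaluated by static type checkers, never at runtime, which is the standard way to reference a type from a module that would otherwise create a circular import. A minimal sketch of the pattern (module and function names here are illustrative, not from the diff):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright, skipped at runtime, so a circular
    # import between the two modules cannot occur.
    from .operator import LowRankOperator  # illustrative


def make_bilinear_form(operator: "LowRankOperator"):
    # The quoted ("forward") annotation is resolved lazily by the
    # type checker, so the name need not exist at runtime.
    return operator
```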


@dataclass(frozen=True)
class TorchBatch(Batch):
@@ -244,7 +261,7 @@ def flat_mixed_grads(self, batch: TorchBatch) -> torch.Tensor:


class OperatorBilinearForm(
    BilinearForm[torch.Tensor, TorchBatch, TorchGradientProvider]
    BilinearForm[torch.Tensor, TorchBatch, TorchGradientProvider],
):
    r"""
    Base class for bilinear forms based on an instance of
@@ -257,7 +274,7 @@ class OperatorBilinearForm(

    def __init__(
        self,
        operator: "TensorOperator",
        operator: TorchOperatorType,
    ):
        self.operator = operator

@@ -406,6 +423,60 @@ def _aggregate_grads(left: torch.Tensor, right: torch.Tensor):
        return torch.einsum("i..., j... -> ij", left, right)
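The einsum in `_aggregate_grads` computes all pairwise inner products of per-sample gradients, contracting every trailing parameter dimension. A minimal self-contained sketch (shapes are illustrative):

```python
import torch

left = torch.randn(3, 5)   # 3 samples, 5 flattened parameters each
right = torch.randn(4, 5)  # 4 samples, 5 flattened parameters each

# "i..., j... -> ij" keeps the two sample dimensions and sums over
# all remaining (parameter) dimensions, giving a (3, 4) Gram matrix.
gram = torch.einsum("i..., j... -> ij", left, right)
assert torch.allclose(gram, left @ right.T)
```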


class LowRankBilinearForm(OperatorBilinearForm):
    def __init__(self, operator: "LowRankOperator"):
        super().__init__(operator)

    def grads_inner_prod(
        self,
        left: TorchBatch,
        right: Optional[TorchBatch],
        gradient_provider: TorchGradientProvider,
    ) -> torch.Tensor:
r"""
Computes the gradient inner product of two batches of data, i.e.
$$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}),
\nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle_{B}$$
where $\nabla_{\omega}\ell(\omega, \cdot, \cdot)$ is represented by the
`gradient_provider` and the expression must be understood sample-wise.
Args:
left: The first batch for gradient and inner product computation
right: The second batch for gradient and inner product computation,
optional; if not provided, the inner product will use the gradient
computed for `left` for both arguments.
gradient_provider: The gradient provider to compute the gradients.
Returns:
A tensor representing the inner products of the per-sample gradients
"""
        op = cast("LowRankOperator", self.operator)

        if op.exact:
            return super().grads_inner_prod(left, right, gradient_provider)

        projections = op.low_rank_representation.projections
        eigen_vals = op.low_rank_representation.eigen_vals
        regularization = op.regularization

        if regularization is not None:
            eigen_vals = eigen_vals + regularization

        left_grads = gradient_provider.jacobian_prod(left, projections.t())
        inverse_regularized_eigenvalues = 1.0 / eigen_vals

        if right is None:
            right_grads = left_grads
        else:
            right_grads = gradient_provider.jacobian_prod(right, projections.t())

        right_grads = right_grads * inverse_regularized_eigenvalues.unsqueeze(-1)

        return torch.einsum("ij, ik -> jk", left_grads, right_grads)


OperatorBilinearFormType = TypeVar(
    "OperatorBilinearFormType", bound=OperatorBilinearForm
)
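To make the low-rank shortcut in `grads_inner_prod` concrete, here is a minimal numeric sketch that mimics the computation with plain tensors standing in for the operator and gradient provider. All names, shapes, and the orthonormal-projection setup are illustrative assumptions, not taken from the diff:

```python
import torch

n_params, rank, n_left, n_right = 8, 3, 4, 5
regularization = 1e-3

# Low-rank eigendecomposition H ~ V diag(lambda) V^T of the (Gauss-Newton) Hessian.
projections = torch.linalg.qr(torch.randn(n_params, rank)).Q  # V, orthonormal columns
eigen_vals = torch.rand(rank) + 0.1                           # positive eigenvalues

# Per-sample gradients of the two batches, shape (batch, n_params).
left_grads_full = torch.randn(n_left, n_params)
right_grads_full = torch.randn(n_right, n_params)

# jacobian_prod(batch, projections.t()) corresponds to projecting the
# per-sample gradients onto the eigenvectors, yielding shape (rank, batch).
left_grads = projections.t() @ left_grads_full.t()
right_grads = projections.t() @ right_grads_full.t()

# Scale one side by the inverse regularized eigenvalues ...
inv_reg_eigen = 1.0 / (eigen_vals + regularization)
right_grads = right_grads * inv_reg_eigen.unsqueeze(-1)

# ... and contract over the rank dimension: entry (j, k) approximates
# <g_left_j, (H + mu*I)^{-1} g_right_k> restricted to the low-rank subspace.
inner_prods = torch.einsum("ij, ik -> jk", left_grads, right_grads)
print(inner_prods.shape)  # torch.Size([4, 5])
```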
