Refactor new structure:

* rename classes
* add concept 'TensorDictOperator', which can act on tensor dictionaries, to avoid
  intermediate flatten and concat steps and reduce memory consumption
schroedk committed May 29, 2024
1 parent f007507 commit ed5f14d
Showing 7 changed files with 747 additions and 183 deletions.
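The central idea of the new `TensorDictOperator` is to act on per-parameter gradient dictionaries directly, block by block, instead of flattening and concatenating them into one large matrix first. A minimal sketch of that idea, assuming a hypothetical block-diagonal operator with one matrix per parameter block (all names and shapes are illustrative, not taken from the commit):

import torch
from typing import Dict

# Hypothetical per-parameter blocks of a block-diagonal operator.
blocks: Dict[str, torch.Tensor] = {
    "layer.weight": 2.0 * torch.eye(6),  # acts on flattened (2, 3) weight gradients
    "layer.bias": 2.0 * torch.eye(3),
}

def apply_blockwise(grads: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Apply each block to the matching per-sample gradients, keeping the
    # dict structure instead of materializing one concatenated matrix.
    out = {}
    for name, g in grads.items():
        flat = g.reshape(g.shape[0], -1)  # (batch_size, block_dim)
        out[name] = (flat @ blocks[name].T).reshape(g.shape)
    return out

# Per-sample gradients for a batch of 4, one entry per model parameter.
grads = {"layer.weight": torch.randn(4, 2, 3), "layer.bias": torch.randn(4, 3)}
result = apply_blockwise(grads)
assert result["layer.weight"].shape == (4, 2, 3)

The per-sample gradients keep their dict-of-tensors layout end to end; the flattened copy that an intermediate concatenation would create never exists, which is the memory saving the commit message refers to.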
204 changes: 193 additions & 11 deletions src/pydvl/influence/torch/base.py
@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, TypeVar, Union
from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast

import torch
from torch.func import functional_call
@@ -13,6 +13,7 @@
from ..types import (
Batch,
BilinearForm,
BilinearFormType,
BlockMapper,
GradientProvider,
Operator,
@@ -47,6 +48,9 @@ class TorchBatch(Batch):
x: torch.Tensor
y: torch.Tensor

def __iter__(self):
return iter((self.x, self.y))

def __post_init__(self):
if self.x.shape[0] != self.y.shape[0]:
raise ValueError(
@@ -310,7 +314,7 @@ class OperatorBilinearForm(

def __init__(
self,
operator: "TorchOperator",
operator: "TensorOperator",
):
self.operator = operator

@@ -345,7 +349,128 @@ def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tensor
return torch.einsum("ia,j...a->ij...", left_result, right)


class TorchOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC):
class DictBilinearForm(OperatorBilinearForm):
r"""
Base class for bilinear forms based on an instance of
[TorchOperator][pydvl.influence.torch.operator.base.TorchOperator]. This means it
computes weighted inner products of the form:
$$ \langle \operatorname{Op}(x), y \rangle $$
"""

def __init__(
self,
operator: "TensorDictOperator",
):
super().__init__(operator)

def grads_inner_prod(
self,
left: TorchBatch,
right: Optional[TorchBatch],
gradient_provider: TorchGradientProvider,
) -> torch.Tensor:
r"""
Computes the gradient inner product of two batches of data, i.e.
$$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}),
\nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle_{B}$$
where $\nabla_{\omega}\ell(\omega, \cdot, \cdot)$ is represented by the
`gradient_provider` and the expression must be understood sample-wise.

Args:
left: The first batch for gradient and inner product computation
right: The second batch for gradient and inner product computation,
optional; if not provided, the inner product will use the gradient
computed for `left` for both arguments.
gradient_provider: The gradient provider to compute the gradients.

Returns:
A tensor representing the inner products of the per-sample gradients
"""
operator = cast(TensorDictOperator, self.operator)
left_grads = gradient_provider.grads(left)
if right is None:
right_grads = left_grads
else:
right_grads = gradient_provider.grads(right)

left_batch_size, right_batch_size = next(
(
(l.shape[0], r.shape[0])
for l, r in zip(left_grads.values(), right_grads.values())
)
)

if left_batch_size <= right_batch_size:
left_grads = operator.apply_to_mat_dict(left_grads)
tensor_pairs = zip(left_grads.values(), right_grads.values())
else:
right_grads = operator.apply_to_mat_dict(right_grads)
tensor_pairs = zip(left_grads.values(), right_grads.values())

tensors_to_reduce = (
self._aggregate_grads(left, right) for left, right in tensor_pairs
)

return cast(torch.Tensor, sum(tensors_to_reduce))

def mixed_grads_inner_prod(
self,
left: TorchBatch,
right: TorchBatch,
gradient_provider: TorchGradientProvider,
) -> torch.Tensor:
r"""
Computes the mixed gradient inner product of two batches of data, i.e.
$$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}),
\nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y})
\rangle_{B}$$
where $\nabla_{\omega}\ell(\omega, \cdot)$ and
$\nabla_{\omega}\nabla_{x}\ell(\omega, \cdot)$ are represented by the
`gradient_provider`. The expression must be understood sample-wise.

Args:
left: The first batch for gradient and inner product computation
right: The second batch for gradient and inner product computation
gradient_provider: The gradient provider to compute the gradients.

Returns:
A tensor representing the inner products of the mixed per-sample gradients
"""
operator = cast(TensorDictOperator, self.operator)
right_grads = gradient_provider.mixed_grads(right)
left_grads = gradient_provider.grads(left)
left_grads = operator.apply_to_mat_dict(left_grads)
left_grads_views = (t.reshape(t.shape[0], -1) for t in left_grads.values())
right_grads_views = (
t.reshape(*right.x.shape, -1) for t in right_grads.values()
)
tensor_pairs = zip(left_grads_views, right_grads_views)
tensors_to_reduce = (
self._aggregate_mixed_grads(left, right) for left, right in tensor_pairs
)
return cast(torch.Tensor, sum(tensors_to_reduce))

@staticmethod
def _aggregate_mixed_grads(left: torch.Tensor, right: torch.Tensor):
return torch.einsum("ik, j...k -> ij...", left, right)

@staticmethod
def _aggregate_grads(left: torch.Tensor, right: torch.Tensor):
return torch.einsum("i..., j... -> ij", left, right)


OperatorBilinearFormType = TypeVar(
"OperatorBilinearFormType", bound=OperatorBilinearForm
)


class TensorOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC):
"""
Abstract base class for operators that can be applied to instances of
[torch.Tensor][torch.Tensor].
@@ -369,13 +494,6 @@ def to(self, device: torch.device):
def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor:
pass

def as_bilinear_form(self):
"""
Represent this operator as a
[OperatorBilinearForm][pydvl.influence.torch.base.OperatorBilinearForm].
"""
return OperatorBilinearForm(self)

def apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor:
"""
Applies the operator to a single vector.
@@ -403,8 +521,72 @@ def apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor:
"""
return torch.func.vmap(self.apply_to_vec, in_dims=0, randomness="same")(mat)

def as_bilinear_form(self) -> OperatorBilinearForm:
return OperatorBilinearForm(self)


class TensorDictOperator(TensorOperator, ABC):
"""
Abstract base class for operators that can be applied to instances of
[torch.Tensor][torch.Tensor] and compatible dictionaries mapping strings to tensors.
Input dictionaries must conform to the structure defined by the property
`input_dict_structure`. Useful for operators involving autograd functionality
to avoid intermediate flattening and concatenating of gradient inputs.
"""

def apply_to_mat_dict(
self, mat: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
Applies the operator to a dictionary of tensors, compatible with the structure
defined by the property `input_dict_structure`.

Args:
mat: dictionary of tensors, whose keys and shapes match the property
`input_dict_structure`.

Returns:
A dictionary of tensors after applying the operator
"""

if not self._validate_mat_dict(mat):
raise ValueError(
f"Incompatible input structure, expected (excluding batch"
f"dimension): \n {self.input_dict_structure}"
)

return self._apply_to_mat_dict(self._dict_to_device(mat))

def _dict_to_device(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return {k: v.to(self.device) for k, v in mat.items()}

@property
@abstractmethod
def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]:
"""
Implement this to expose the expected structure of the input tensor dict, i.e.
a dictionary of shapes (excluding the first batch dimension), in order
to validate the input tensor dicts.
"""

@abstractmethod
def _apply_to_mat_dict(
self, mat: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
pass

def _validate_mat_dict(self, mat: Dict[str, torch.Tensor]) -> bool:
for key, val in mat.items():
if val.shape[1:] != self.input_dict_structure[key]:
return False
return True

def as_bilinear_form(self) -> DictBilinearForm:
return DictBilinearForm(self)


TorchOperatorType = TypeVar("TorchOperatorType", bound=TorchOperator)
TorchOperatorType = TypeVar("TorchOperatorType", bound=TensorOperator)


class TorchOperatorGradientComposition(
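The two einsum aggregations in `DictBilinearForm` compute pairwise per-sample inner products block by block; summed over the parameter blocks, they equal the inner products of the fully flattened gradients. A small self-contained check of that equivalence (tensor names and shapes are illustrative):

import torch

# Per-sample gradients for two batches, split into two parameter blocks.
left = {"w": torch.randn(4, 2, 3), "b": torch.randn(4, 3)}
right = {"w": torch.randn(5, 2, 3), "b": torch.randn(5, 3)}

# Block-wise aggregation as in DictBilinearForm._aggregate_grads: each block
# contributes an (n_left, n_right) matrix of per-sample inner products.
blockwise = sum(torch.einsum("i..., j... -> ij", left[k], right[k]) for k in left)

# Reference: flatten and concatenate all blocks, then take inner products.
def flat(d):
    return torch.cat([t.reshape(t.shape[0], -1) for t in d.values()], dim=1)

reference = flat(left) @ flat(right).T
assert torch.allclose(blockwise, reference, atol=1e-5)

This sum over blocks is exactly the sum over `tensors_to_reduce` in `grads_inner_prod`, so the full flattened gradient matrix never has to be materialized.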
