
Commit

Merge pull request #593 from aai-institute/feature/589-lissa-block-diagonal

Feature/589 lissa block diagonal
schroedk authored Jun 10, 2024
2 parents cc3bb2b + 68cd60d commit ea3b86f
Showing 7 changed files with 275 additions and 140 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.md
@@ -14,6 +14,8 @@
 - Extend `DirectInfluence` with block-diagonal and Gauss-Newton
   approximation
   [PR #591](https://github.com/aai-institute/pyDVL/pull/591)
+- Extend `LissaInfluence` with block-diagonal and Gauss-Newton approximation
+  [PR #593](https://github.com/aai-institute/pyDVL/pull/593)

## Changed

@@ -22,7 +24,12 @@
   to `regularization` and change the type annotation to allow
   for block-wise regularization parameters
   [PR #591](https://github.com/aai-institute/pyDVL/pull/591)
+- Rename parameter `hessian_regularization` of `LissaInfluence`
+  to `regularization` and change the type annotation to allow
+  for block-wise regularization parameters
+  [PR #593](https://github.com/aai-institute/pyDVL/pull/593)
+- Remove parameter `h0` from init of `LissaInfluence`
+  [PR #593](https://github.com/aai-institute/pyDVL/pull/593)

## 0.9.2 - 🏗 Bug fixes, logging improvement

Expand Down
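The two `LissaInfluence` entries above amount to the following call-site change. A minimal sketch — the model, block names, and parameter names are hypothetical; only the annotations `Optional[Union[float, Dict[str, Optional[float]]]]` and `OrderedDict[str, List[str]]` visible in this PR are assumed:

```python
from collections import OrderedDict

import torch

from pydvl.influence.torch import LissaInfluence

# Hypothetical model; "0.weight", "0.bias", ... are the parameter names
# that nn.Sequential assigns automatically.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
)
loss = torch.nn.MSELoss()

# Custom block structure: block names mapped to lists of parameter names,
# per the `OrderedDict[str, List[str]]` annotation introduced in this PR.
blocks = OrderedDict(
    first_layer=["0.weight", "0.bias"],
    second_layer=["2.weight", "2.bias"],
)

# `regularization` (formerly `hessian_regularization`) now also accepts a
# dict keyed by block name; a plain float still applies to all blocks.
if_model = LissaInfluence(
    model,
    loss,
    regularization={"first_layer": 0.1, "second_layer": None},
    block_structure=blocks,
)
```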
16 changes: 10 additions & 6 deletions docs/influence/influence_function_model.md
@@ -74,25 +74,29 @@ original paper [@agarwal_secondorder_2017].


```python
-from pydvl.influence.torch import LissaInfluence
+from pydvl.influence.torch import LissaInfluence, BlockMode, SecondOrderMode
 if_model = LissaInfluence(
     model,
     loss,
-    hessian_regularization=0.0,
+    regularization=0.0,
     maxiter=1000,
     dampen=0.0,
     scale=10.0,
-    h0=None,
     rtol=1e-4,
+    block_structure=BlockMode.FULL,
+    second_order_mode=SecondOrderMode.GAUSS_NEWTON,
 )
 if_model.fit(train_loader)
```

-with the additional optional parameters `maxiter`, `dampen`, `scale`, `h0`,
-and `rtol`, being the maximum number of iterations, the dampening factor,
-the scaling factor, the initial guess for the solution and the relative
-tolerance, respectively.
+with the additional optional parameters `maxiter`, `dampen`, `scale`, and
+`rtol`, being the maximum number of iterations, the dampening factor, the
+scaling factor and the relative tolerance, respectively. This implementation
+is capable of using a block-matrix approximation, see
+[Block-diagonal approximation](#block-diagonal-approximation), and can handle
+[Gauss-Newton approximation](#gauss-newton-approximation).
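
To make the role of `maxiter`, `dampen`, `scale`, and `rtol` concrete, here is a minimal self-contained sketch of the LISSA recursion, mirroring the update rule visible in the removed `_solve_hvp` code further below. A small fixed matrix stands in for the batched Hessian-vector products, so this is an illustration rather than the pyDVL implementation:

```python
import torch

# Standalone sketch of the LISSA recursion:
# iterate  h <- b + (1 - dampen) * h - (H @ h + reg * h) / scale,
# stop when the max relative residual falls below rtol, return h / scale.
def lissa_solve(hvp, b, regularization=0.0, maxiter=1000,
                dampen=0.0, scale=10.0, rtol=1e-4):
    h = torch.clone(b)
    for _ in range(maxiter):
        reg_hvp = hvp(h) + regularization * h
        h_new = b + (1 - dampen) * h - reg_hvp / scale
        residual = h_new - h
        h = h_new
        if torch.max(torch.abs(residual / h)) < rtol:
            break
    return h / scale

# With a fixed SPD matrix standing in for the Hessian, the result
# approaches the exact solution of H x = b:
H = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
b = torch.tensor([1.0, 1.0])
print(lissa_solve(lambda v: H @ v, b))  # ~ tensor([0.2, 0.4])
print(torch.linalg.solve(H, b))         # exact: tensor([0.2, 0.4])
```

With `dampen=0` the fixed point of the recursion is \(h^* = s\,H^{-1}b\), which the final division by `scale` undoes; the iteration converges when the eigenvalues of the regularized Hessian lie in \((0, 2s)\), which is why increasing `scale` (or the dampening) is the usual remedy for divergence.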

### Arnoldi

79 changes: 33 additions & 46 deletions notebooks/influence_wine.ipynb

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions src/pydvl/influence/torch/batch_operation.py
@@ -58,7 +58,7 @@ def dtype(self):
        return next(self.model.parameters()).dtype

    @property
-    def input_size(self):
+    def input_size(self) -> int:
        return sum(p.numel() for p in self.params_to_restrict_to.values())

    def to(self, device: torch.device):
@@ -136,7 +136,7 @@ def apply(self, batch: TorchBatch, tensor: torch.Tensor):
            "property `input_size`."
        )

-        if tensor.ndim == 2:
+        if tensor.ndim == 2 and tensor.shape[0] > 1:
            return self._apply_to_mat(batch.to(self.device), tensor.to(self.device))
        return self._apply_to_vec(batch.to(self.device), tensor.to(self.device))
@@ -154,11 +154,14 @@ def _apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor:
            $(N, \text{input_size})$
        """
-        return torch.func.vmap(
+        result = torch.func.vmap(
            lambda _x, _y, m: self._apply_to_vec(TorchBatch(_x, _y), m),
            in_dims=(None, None, 0),
            randomness="same",
        )(batch.x, batch.y, mat)
+        if result.requires_grad:
+            result = result.detach()
+        return result

class HessianBatchOperation(_ModelBasedBatchOperation):
@@ -194,7 +197,10 @@ def __init__(
        self.loss = loss

    def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor:
-        return self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec)
+        result = self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec)
+        if result.requires_grad:
+            result = result.detach()
+        return result

    def _apply_to_dict(
        self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor]
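For context on the `_apply_to_mat` change above: `torch.func.vmap` with `in_dims=(None, None, 0)` broadcasts the batch and maps a per-vector operation over the rows of a matrix. A minimal standalone sketch — the toy function below is a stand-in, not pyDVL's Hessian-vector product:

```python
import torch

# Lift an operation defined for a single vector to a matrix of stacked
# vectors, the same pattern _apply_to_mat uses to reuse _apply_to_vec.
def apply_to_vec(x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Stand-in for a Hessian-vector product: scale v by a batch statistic.
    return (x * y).sum() * v

x, y = torch.randn(8), torch.randn(8)
mat = torch.randn(5, 8)  # five stacked vectors of size 8

result = torch.func.vmap(
    apply_to_vec,
    in_dims=(None, None, 0),  # broadcast x and y, map over the rows of mat
    randomness="same",
)(x, y, mat)

print(result.shape)  # torch.Size([5, 8])
```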
146 changes: 64 additions & 82 deletions src/pydvl/influence/torch/influence_function_model.py
@@ -28,6 +28,11 @@
    TorchGradientProvider,
    TorchOperatorGradientComposition,
 )
+from .batch_operation import (
+    BatchOperationType,
+    GaussNewtonBatchOperation,
+    HessianBatchOperation,
+)
 from .functional import (
    LowRankProductRepresentation,
    create_batch_hvp_function,
@@ -40,7 +45,7 @@
    model_hessian_low_rank,
    model_hessian_nystroem_approximation,
 )
-from .operator import DirectSolveOperator, InverseHarmonicMeanOperator
+from .operator import DirectSolveOperator, InverseHarmonicMeanOperator, LissaOperator
 from .pre_conditioner import PreConditioner
 from .util import (
    BlockMode,
@@ -789,7 +794,7 @@ def to(self, device: torch.device):
        return super().to(device)


-class LissaInfluence(TorchInfluenceFunctionModel):
+class LissaInfluence(TorchComposableInfluence[LissaOperator[BatchOperationType]]):
    r"""
    Uses LISSA, Linear time Stochastic Second-Order Algorithm, to iteratively
    approximate the inverse Hessian. More precisely, it finds x s.t. \(Hx = b\),
@@ -803,126 +808,103 @@ class LissaInfluence(TorchInfluenceFunctionModel):
    see [Linear time Stochastic Second-Order Approximation (LiSSA)]
    [linear-time-stochastic-second-order-approximation-lissa]

    Args:
        model: A PyTorch model. The Hessian will be calculated with respect to
            this model's parameters.
        loss: A callable that takes the model's output and target as input and
            returns the scalar loss.
-        hessian_regularization: Optional regularization parameter added
+        regularization: Optional regularization parameter added
            to the Hessian-vector product for numerical stability.
        maxiter: Maximum number of iterations.
        dampen: Dampening factor, defaults to 0 for no dampening.
        scale: Scaling factor, defaults to 10.
-        h0: Initial guess for hvp.
        rtol: Tolerance to use for early stopping.
        progress: If True, display progress bars.
        warn_on_max_iteration: If True, logs a warning if the desired tolerance is not
            achieved within `maxiter` iterations. If False, the log level for this
            information is `logging.DEBUG`.
+        block_structure: The blocking structure, either a pre-defined enum or a
+            custom block structure, see the information regarding
+            [block-diagonal approximation][block-diagonal-approximation].
+        second_order_mode: The second order mode, either `SecondOrderMode.HESSIAN` or
+            `SecondOrderMode.GAUSS_NEWTON`.
    """

    def __init__(
        self,
        model: nn.Module,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
-        hessian_regularization: float = 0.0,
+        regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None,
        maxiter: int = 1000,
        dampen: float = 0.0,
        scale: float = 10.0,
-        h0: Optional[torch.Tensor] = None,
        rtol: float = 1e-4,
        progress: bool = False,
        warn_on_max_iteration: bool = True,
+        block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL,
+        second_order_mode: SecondOrderMode = SecondOrderMode.HESSIAN,
    ):
-        super().__init__(model, loss)
-        self.warn_on_max_iteration = warn_on_max_iteration
+        super().__init__(model, block_structure, regularization)
        self.maxiter = maxiter
-        self.hessian_regularization = hessian_regularization
        self.progress = progress
        self.rtol = rtol
-        self.h0 = h0
        self.scale = scale
        self.dampen = dampen
+        self.loss = loss
+        self.second_order_mode = second_order_mode
+        self.warn_on_max_iteration = warn_on_max_iteration

-    train_dataloader: DataLoader
-
-    @property
-    def is_fitted(self):
-        try:
-            return self.train_dataloader is not None
-        except AttributeError:
-            return False
-
-    @log_duration(log_level=logging.INFO)
-    def fit(self, data: DataLoader) -> LissaInfluence:
-        self.train_dataloader = data
-        return self
-
-    @log_duration
-    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
-        h_estimate = self.h0 if self.h0 is not None else torch.clone(rhs)
-
-        shuffled_training_data = DataLoader(
-            self.train_dataloader.dataset,
-            self.train_dataloader.batch_size,
-            shuffle=True,
-        )
-
-        def lissa_step(
-            h: torch.Tensor, reg_hvp: Callable[[torch.Tensor], torch.Tensor]
-        ) -> torch.Tensor:
-            """Given an estimate of the hessian inverse and the regularised hessian
-            vector product, it computes the next estimate.
-
-            Args:
-                h: An estimate of the hessian inverse.
-                reg_hvp: Regularised hessian vector product.
-
-            Returns:
-                The next estimate of the hessian inverse.
-            """
-            return rhs + (1 - self.dampen) * h - reg_hvp(h) / self.scale
-
-        model_params = {
-            k: p.detach() for k, p in self.model.named_parameters() if p.requires_grad
-        }
-        b_hvp = torch.vmap(
-            create_batch_hvp_function(self.model, self.loss),
-            in_dims=(None, None, None, 0),
-        )
-        for k in tqdm(
-            range(self.maxiter), disable=not self.progress, desc="Lissa iteration"
-        ):
-            x, y = next(iter(shuffled_training_data))
-            x = x.to(self.model_device)
-            y = y.to(self.model_device)
-            reg_hvp = (
-                lambda v: b_hvp(model_params, x, y, v) + self.hessian_regularization * v
-            )
-            residual = lissa_step(h_estimate, reg_hvp) - h_estimate
-            h_estimate += residual
-            if torch.isnan(h_estimate).any():
-                raise RuntimeError("NaNs in h_estimate. Increase scale or dampening.")
-            max_residual = torch.max(torch.abs(residual / h_estimate))
-            if max_residual < self.rtol:
-                mean_residual = torch.mean(torch.abs(residual / h_estimate))
-                logger.debug(
-                    f"Terminated Lissa after {k} iterations with "
-                    f"{max_residual*100:.2f} % max residual and"
-                    f" mean residual {mean_residual*100:.5f} %"
-                )
-                break
-        else:
-            mean_residual = torch.mean(torch.abs(residual / h_estimate))
-            log_level = logging.WARNING if self.warn_on_max_iteration else logging.DEBUG
-            logger.log(
-                log_level,
-                f"Reached max number of iterations {self.maxiter} without "
-                f"achieving the desired tolerance {self.rtol}.\n "
-                f"Achieved max residual {max_residual*100:.2f} % and"
-                f" {mean_residual*100:.5f} % mean residual",
-            )
-        return h_estimate / self.scale
+    def with_regularization(
+        self, regularization: Union[float, Dict[str, Optional[float]]]
+    ) -> TorchComposableInfluence:
+        """
+        Update the regularization parameter.
+
+        Args:
+            regularization: Either a positive float or a dictionary with the
+                block names as keys and the regularization values as values.
+
+        Returns:
+            The modified instance
+        """
+        self._regularization_dict = self._build_regularization_dict(regularization)
+        for k, reg in self._regularization_dict.items():
+            self.block_mapper.composable_block_dict[k].op.regularization = reg
+        return self
+
+    def _create_block(
+        self,
+        block_params: Dict[str, torch.nn.Parameter],
+        data: DataLoader,
+        regularization: Optional[float],
+    ) -> TorchOperatorGradientComposition:
+        gp = TorchGradientProvider(self.model, self.loss, restrict_to=block_params)
+        batch_op: Union[GaussNewtonBatchOperation, HessianBatchOperation]
+        if self.second_order_mode is SecondOrderMode.GAUSS_NEWTON:
+            batch_op = GaussNewtonBatchOperation(
+                self.model, self.loss, restrict_to=block_params
+            )
+        else:
+            batch_op = HessianBatchOperation(
+                self.model, self.loss, restrict_to=block_params
+            )
+        lissa_op = LissaOperator(
+            batch_op,
+            data,
+            regularization,
+            maxiter=self.maxiter,
+            dampen=self.dampen,
+            scale=self.scale,
+            rtol=self.rtol,
+            progress=self.progress,
+            warn_on_max_iteration=self.warn_on_max_iteration,
+        )
+        return TorchOperatorGradientComposition(lissa_op, gp)
+
+    @property
+    def is_thread_safe(self) -> bool:
+        return False

class ArnoldiInfluence(TorchInfluenceFunctionModel):
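A short usage sketch of the new `with_regularization` method from the diff above. The dataset and model are hypothetical; `fit` follows the same pattern as `if_model.fit(train_loader)` in the docs section:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch import LissaInfluence

# Toy regression data and model, for illustration only.
X, Y = torch.randn(32, 4), torch.randn(32, 1)
train_loader = DataLoader(TensorDataset(X, Y), batch_size=8)

model = torch.nn.Linear(4, 1)
if_model = LissaInfluence(model, torch.nn.MSELoss(), regularization=0.1)
if_model = if_model.fit(train_loader)

# Change the regularization of the fitted model without refitting;
# a dict keyed by block name is also accepted.
if_model = if_model.with_regularization(1.0)
```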
