diff --git a/trak/score_computers.py b/trak/score_computers.py index 120e163..f82cf56 100644 --- a/trak/score_computers.py +++ b/trak/score_computers.py @@ -122,6 +122,7 @@ def __init__( device: torch.device, CUDA_MAX_DIM_SIZE: int = 20_000, logging_level=logging.INFO, + lambda_reg: float = 0.0, ) -> None: """ Args: @@ -132,11 +133,14 @@ def __init__( Size of block for block-wise matmuls. Defaults to 100_000. logging_level (logging level, optional): Logging level for the logger. Defaults to logging.info. + lambda_reg (int): + regularization term for l2 reg on xtx """ super().__init__(dtype, device) self.CUDA_MAX_DIM_SIZE = CUDA_MAX_DIM_SIZE self.logger = logging.getLogger("ScoreComputer") self.logger.setLevel(logging_level) + self.lambda_reg = lambda_reg def get_xtx(self, grads: Tensor) -> Tensor: self.proj_dim = grads.shape[1] @@ -152,7 +156,11 @@ def get_xtx(self, grads: Tensor) -> Tensor: def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor: blocks = ch.split(grads, split_size_or_sections=self.CUDA_MAX_DIM_SIZE, dim=0) - xtx_inv = ch.linalg.inv(xtx.to(ch.float32)) + + xtx_reg = xtx + self.lambda_reg * torch.eye( + xtx.size(dim=0), device=xtx.device, dtype=xtx.dtype + ) + xtx_inv = ch.linalg.inv(xtx_reg.to(ch.float32)) # center X^TX inverse a bit to avoid numerical issues when going to float16 xtx_inv /= xtx_inv.abs().mean() diff --git a/trak/traker.py b/trak/traker.py index 21e30f1..907935e 100644 --- a/trak/traker.py +++ b/trak/traker.py @@ -64,6 +64,7 @@ def __init__( proj_max_batch_size: int = 32, projector_seed: int = 0, grad_wrt: Optional[Iterable[str]] = None, + lambda_reg: float = 0.0, ) -> None: """ @@ -127,7 +128,10 @@ def __init__( as they appear in the model's state dictionary. If None, gradients are taken with respect to all model parameters. Defaults to None. - + lambda_reg (float): + The :math:`\ell_2` (ridge) regularization penalty added to the + :math:`XTX` term in score computers when computing the matrix + inverse :math:`(XTX)^{-1}`. Defaults to 0. """ self.model = model @@ -136,6 +140,7 @@ def __init__( self.device = device self.dtype = ch.float16 if use_half_precision else ch.float32 self.grad_wrt = grad_wrt + self.lambda_reg = lambda_reg logging.basicConfig() self.logger = logging.getLogger("TRAK") @@ -181,7 +186,10 @@ def __init__( if score_computer is None: score_computer = BasicScoreComputer self.score_computer = score_computer( - dtype=self.dtype, device=self.device, logging_level=logging_level + dtype=self.dtype, + device=self.device, + logging_level=logging_level, + lambda_reg=self.lambda_reg, ) metadata = {