Skip to content

Commit

Permalink
Updated traker.py to include a lambda_reg term in arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
heale04 authored Jan 16, 2024
1 parent 1937dea commit 6ce7847
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion trak/traker.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
proj_max_batch_size: int = 32,
projector_seed: int = 0,
grad_wrt: Optional[Iterable[str]] = None,
lambda_reg: int = 0

) -> None:
"""
Expand Down Expand Up @@ -129,6 +131,9 @@ def __init__(
as they appear in the model's state dictionary. If None,
gradients are taken with respect to all model parameters.
Defaults to None.
lambda_reg (int):
Applies L2 regularization to the xtx term in scorecomputers
with form xtx + lambda_reg*1. Defaults to 0
"""

Expand All @@ -138,6 +143,7 @@ def __init__(
self.device = device
self.dtype = ch.float16 if use_half_precision else ch.float32
self.grad_wrt = grad_wrt
self.lambda_reg = lambda_reg

logging.basicConfig()
self.logger = logging.getLogger("TRAK")
Expand Down Expand Up @@ -183,7 +189,7 @@ def __init__(
if score_computer is None:
score_computer = BasicScoreComputer
self.score_computer = score_computer(
dtype=self.dtype, device=self.device, logging_level=logging_level
dtype=self.dtype, device=self.device, logging_level=logging_level, lambda_reg = self.lambda_reg
)

metadata = {
Expand Down

0 comments on commit 6ce7847

Please sign in to comment.