Commit 62426eb

save on I/O overhead by only writing once to disk when scoring
kristian-georgiev committed Nov 1, 2023
1 parent cbd3ecb commit 62426eb
Showing 2 changed files with 29 additions and 21 deletions.
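
In short: previously, finalize_scores added each checkpoint's scores directly into the on-disk score memmap, incurring a disk write per checkpoint; after this change, scores accumulate in a pinned CPU tensor and the memmap is written exactly once at the end. A minimal sketch of the pattern, with illustrative shapes and file names (not taken from the codebase):

    import numpy as np
    import torch as ch

    n, m, num_ckpts = 10_000, 500, 5  # train examples, targets, checkpoints
    scores_mmap = np.lib.format.open_memmap(
        "scores.npy", mode="w+", dtype=np.float32, shape=(n, m)
    )

    # accumulate in RAM instead of writing to the disk-backed array per checkpoint
    accumulator = ch.zeros(n, m, device="cpu")
    for _ in range(num_ckpts):
        ckpt_scores = ch.randn(n, m)  # stand-in for one checkpoint's score matrix
        accumulator += ckpt_scores

    scores_mmap[:] = accumulator.numpy() / num_ckpts  # single write to disk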
trak/score_computers.py (27 changes: 19 additions & 8 deletions)
@@ -66,7 +66,9 @@ def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:
         """
 
     @abstractmethod
-    def get_scores(self, features: Tensor, target_grads: Tensor) -> Tensor:
+    def get_scores(
+        self, features: Tensor, target_grads: Tensor, accumulator: Tensor
+    ) -> None:
         """Computes the scores for a given set of features and target gradients.
         In particular, this function takes in a matrix of features
         :math:`\Phi=X(X^\top X)^{-1}`, computed by the :code:`get_x_xtx_inv`
@@ -75,14 +77,17 @@ def get_scores(self, features: Tensor, target_grads: Tensor) -> Tensor:
         resulting matrix has shape :code:`(n, m)`, where :math:`n` is the number
         of training examples and :math:`m` is the number of target examples.
+        The :code:`accumulator` argument is used to store the result of the
+        computation. This is useful when computing scores for multiple model
+        checkpoints, as it allows us to re-use the same memory for the score
+        matrix.
         Args:
             features (Tensor): features :math:`\Phi` of shape :code:`(n, p)`.
             target_grads (Tensor):
                 target projected gradients :math:`X_{target}` of shape
                 :code:`(m, p)`.
-        Returns:
-            Tensor: scores of shape :code:`(n, m)`.
+            accumulator (Tensor): accumulator of shape :code:`(n, m)`.
         """


@@ -100,8 +105,10 @@ def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:
         # torch.linalg.inv does not support float16
         return grads @ ch.linalg.inv(xtx.float()).to(self.dtype)
 
-    def get_scores(self, features: Tensor, target_grads: Tensor) -> Tensor:
-        return features @ target_grads.T
+    def get_scores(
+        self, features: Tensor, target_grads: Tensor, accumulator: Tensor
+    ) -> None:
+        accumulator += (features @ target_grads.T).detach().cpu()
 
 
 class BasicScoreComputer(AbstractScoreComputer):
@@ -161,10 +168,14 @@ def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:
             result[start:end] = block.to(self.device) @ xtx_inv
         return result
 
-    def get_scores(self, features: Tensor, target_grads: Tensor) -> Tensor:
+    def get_scores(
+        self, features: Tensor, target_grads: Tensor, accumulator: Tensor
+    ) -> None:
         train_dim = features.shape[0]
         target_dim = target_grads.shape[0]
 
         self.logger.debug(f"{train_dim=}, {target_dim=}")
 
-        return get_matrix_mult(features=features, target_grads=target_grads)
+        accumulator += (
+            get_matrix_mult(features=features, target_grads=target_grads).detach().cpu()
+        )
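
As the updated docstring explains, callers now allocate one score buffer and pass it to get_scores for every checkpoint. A rough usage sketch, with a toy computer standing in for the concrete classes above (names and shapes here are illustrative):

    import torch as ch

    class ToyScoreComputer:
        # mirrors the in-place accumulation introduced in this commit
        def get_scores(self, features, target_grads, accumulator) -> None:
            accumulator += (features @ target_grads.T).detach().cpu()

    n_train, n_targets, proj_dim = 1_000, 100, 512
    computer = ToyScoreComputer()

    scores = ch.zeros(n_train, n_targets)  # one buffer, reused across checkpoints
    for _ in range(3):  # e.g. three model checkpoints
        features = ch.randn(n_train, proj_dim)
        target_grads = ch.randn(n_targets, proj_dim)
        computer.get_scores(features, target_grads, accumulator=scores)
    scores /= 3  # average over checkpoints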
trak/traker.py (23 changes: 10 additions & 13 deletions)
@@ -127,10 +127,6 @@ def __init__(
         logging.basicConfig()
         self.logger = logging.getLogger("TRAK")
         self.logger.setLevel(logging_level)
-        self.logger.warning(
-            "TRAK is still in an early 0.x.x version.\n\
-            Report any issues at https://github.com/MadryLab/trak/issues"
-        )
 
         self.num_params = get_num_params(self.model)
         # inits self.projector
@@ -533,8 +529,9 @@ def finalize_scores(
         _completed = [False] * len(model_ids)
 
         self.saver.load_current_store(list(model_ids.keys())[0], exp_name, num_targets)
-        _scores = self.saver.current_store[f"{exp_name}_scores"]
-        _scores[:] = 0.0
+        _scores_mmap = self.saver.current_store[f"{exp_name}_scores"]
+        _scores_on_cpu = ch.zeros(*_scores_mmap.shape, device="cpu")
+        _scores_on_cpu = _scores_on_cpu.pin_memory()
 
         _avg_out_to_losses = np.zeros(
             (self.saver.train_set_size, 1),
@@ -567,21 +564,21 @@ def finalize_scores(
                 self.saver.current_store[f"{exp_name}_grads"], device=self.device
             )
 
-            # TODO: do this in-place
-            _scores[:] += (
-                self.score_computer.get_scores(g, g_target).cpu().detach().numpy()
-            )
+            self.score_computer.get_scores(g, g_target, accumulator=_scores_on_cpu)
 
             _avg_out_to_losses += self.saver.current_store["out_to_loss"]
             _completed[j] = True
 
         _num_models_used = float(sum(_completed))
-        _scores[:] = (_scores / _num_models_used) * (
+
+        # only write to mmap (on disk) once at the end
+        _scores_mmap[:] = (_scores_on_cpu.numpy() / _num_models_used) * (
             _avg_out_to_losses / _num_models_used
         )
 
-        self.logger.debug(f"Scores dtype is {_scores.dtype}")
+        self.logger.debug(f"Scores dtype is {_scores_mmap.dtype}")
         self.saver.save_scores(exp_name)
-        self.scores = _scores
+        self.scores = _scores_mmap
 
         return self.scores
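
A note on the pinned CPU buffer: pinning host memory speeds up device-to-host copies and enables asynchronous ones, which is relevant here since every checkpoint's GPU score block is moved into the CPU accumulator. One detail worth remembering is that Tensor.pin_memory() returns a new pinned tensor rather than pinning in place, so its result must be kept. A minimal sketch:

    import torch as ch

    buf = ch.zeros(4, 4)
    pinned = buf.pin_memory()  # returns a new pinned tensor; buf is unchanged
    assert pinned.is_pinned() and not buf.is_pinned()

    # pinned host memory allows fast (and non-blocking) GPU-to-CPU copies
    if ch.cuda.is_available():
        gpu_block = ch.randn(4, 4, device="cuda")
        pinned.copy_(gpu_block, non_blocking=True)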

1 comment on commit 62426eb

@kristian-georgiev (Member, Author) commented:


Co-authored with @AlaaKhaddaj
