diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..1cf14dd --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,3 @@ +# https://black.readthedocs.io/en/stable/guides/introducing_black_to_your_project.html +# Migrate code style to Black +3141015f3687dc11c311f1270c7dff80f1299fe3 diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index d7fc2c6..a24bfa8 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -33,11 +33,6 @@ jobs: uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - - name: cuda-toolkit - uses: Jimver/cuda-toolkit@v0.2.8 - id: cuda-toolkit - with: - cuda: '11.7.0' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/README.md b/README.md index 70506c7..11f5fbd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,9 @@ [![arXiv](https://img.shields.io/badge/arXiv-2303.14186-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2303.14186) [![PyPI version](https://badge.fury.io/py/traker.svg)](https://badge.fury.io/py/traker) [![Documentation Status](https://readthedocs.org/projects/trak/badge/?version=latest)](https://trak.readthedocs.io/en/latest/?badge=latest) +[![Code style: +black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + # TRAK: Attributing Model Behavior at Scale diff --git a/docs/source/bert.rst b/docs/source/bert.rst index 59ed77f..b37d8fc 100644 --- a/docs/source/bert.rst +++ b/docs/source/bert.rst @@ -63,7 +63,7 @@ to fit our API signatures. We slightly redefine the :code:`forward` function so that we can pass in the inputs (:code:`input_ids`, etc.) as positional arguments instead of as keyword arguments. -For data loading, we adapt the code from Hugging Face example: +For data loading, we adapt the code from the HuggingFace example: .. raw:: html @@ -132,7 +132,7 @@ For data loading, we adapt the code from Hugging Face example: # NOTE: CHANGE THIS IF YOU WANT TO RUN ON FULL DATASET TRAIN_SET_SIZE = 5_000 - VAL_SET_SIZE = 1_00 + VAL_SET_SIZE = 10 def init_loaders(batch_size=16): ds_train = get_dataset('train') @@ -180,38 +180,59 @@ The model output function is implemented as follows: .. 
code-block:: python - def get_output(func_model, - weights: Iterable[Tensor], - buffers: Iterable[Tensor], - input_id: Tensor, - token_type_id: Tensor, - attention_mask: Tensor, - label: Tensor, - ) -> Tensor: - logits = func_model(weights, buffers, input_id.unsqueeze(0), - token_type_id.unsqueeze(0), - attention_mask.unsqueeze(0)) + def get_output( + model, + weights: Iterable[Tensor], + buffers: Iterable[Tensor], + input_id: Tensor, + token_type_id: Tensor, + attention_mask: Tensor, + label: Tensor, + ) -> Tensor: + kw_inputs = { + "input_ids": input_id.unsqueeze(0), + "token_type_ids": token_type_id.unsqueeze(0), + "attention_mask": attention_mask.unsqueeze(0), + } + + logits = ch.func.functional_call( + model, (weights, buffers), args=(), kwargs=kw_inputs + ) bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False) logits_correct = logits[bindex, label.unsqueeze(0)] cloned_logits = logits.clone() - cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(-ch.inf).to(logits.device) + cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor( + -ch.inf, device=logits.device, dtype=logits.dtype + ) margins = logits_correct - cloned_logits.logsumexp(dim=-1) return margins.sum() -The implementation is identical to the standard classification example in :ref:`MODELOUTPUT tutorial`, -except here the signature of the method and the :code:`func_model` is slightly different -as the language model takes in three inputs instead of just one. +The implementation is identical to the standard classification example in +:ref:`MODELOUTPUT tutorial`, except here the signature of the method and the +:code:`func_model` is slightly different as the language model takes in three +inputs instead of just one. Similarly, the gradient function is implemented as follows: .. code-block:: python - def get_out_to_loss_grad(self, func_model, weights, buffers, batch: Iterable[Tensor]) -> Tensor: + def get_out_to_loss_grad( + self, model, weights, buffers, batch: Iterable[Tensor] + ) -> Tensor: input_ids, token_type_ids, attention_mask, labels = batch - logits = func_model(weights, buffers, input_ids, token_type_ids, attention_mask) - ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels] + kw_inputs = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": attention_mask, + } + logits = ch.func.functional_call( + model, (weights, buffers), args=(), kwargs=kw_inputs + ) + ps = self.softmax(logits / self.loss_temperature)[ + ch.arange(logits.size(0)), labels + ] return (1 - ps).clone().detach().unsqueeze(-1) Putting it together @@ -221,12 +242,14 @@ Using the above :code:`TextClassificationModelOutput` implementation, we can com .. 
code-block:: python - traker = TRAKer(model=model, - task=TextClassificationModelOutput, # you can also just pass in "text_classification" - train_set_size=TRAIN_SET_SIZE, - save_dir=args.out, - device=device, - proj_dim=1024) + traker = TRAKer( + model=model, + task=TextClassificationModelOutput, # you can also just pass in "text_classification" + train_set_size=TRAIN_SET_SIZE, + save_dir=SAVE_DIR, + device=DEVICE, + proj_dim=1024, + ) def process_batch(batch): return batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['labels'] @@ -235,18 +258,21 @@ Using the above :code:`TextClassificationModelOutput` implementation, we can com for batch in tqdm(loader_train, desc='Featurizing..'): # process batch into compatible form for TRAKer TextClassificationModelOutput batch = process_batch(batch) - batch = [x.cuda() for x in batch] + batch = [x.to(DEVICE) for x in batch] traker.featurize(batch=batch, num_samples=batch[0].shape[0]) traker.finalize_features() - traker.start_scoring_checkpoint(model.state_dict(), model_id=0, num_targets=VAL_SET_SIZE) + traker.start_scoring_checkpoint(exp_name='qnli', + checkpoint=model.state_dict(), + model_id=0, + num_targets=VAL_SET_SIZE) for batch in tqdm(loader_val, desc='Scoring..'): batch = process_batch(batch) batch = [x.cuda() for x in batch] traker.score(batch=batch, num_samples=batch[0].shape[0]) - scores = traker.finalize_scores() + scores = traker.finalize_scores(exp_name='qnli') We use :code:`process_batch` to transform the batch from a dictionary (which is the form used by Hugging Face dataloaders) to a tuple. @@ -256,4 +282,5 @@ That's all! You can find this tutorial as a complete script in `here =2.0.0", - "numpy", - "tqdm", - ], - extras_require={ - 'tests': - ["assertpy", - "torchvision", - "open_clip_torch", - "wget", - "scipy", - ], - 'fast': - ["fast_jl" - ]}, - include_package_data=True, - ) +setup( + name="traker", + version="0.3.0", + description="TRAK: Attributing Model Behavior at Scale", + long_description="Check https://trak.csail.mit.edu/ to learn more about TRAK", + author="MadryLab", + author_email="trak@mit.edu", + license_files=("LICENSE.txt",), + packages=["trak"], + install_requires=[ + "torch>=2.0.0", + "numpy", + "tqdm", + ], + extras_require={ + "tests": [ + "assertpy", + "torchvision", + "open_clip_torch", + "wget", + "scipy", + "datasets", + "transformers", + ], + "fast": ["fast_jl"], + }, + include_package_data=True, +) diff --git a/tests/autocast.py b/tests/autocast.py index 1feabf3..1771262 100644 --- a/tests/autocast.py +++ b/tests/autocast.py @@ -23,33 +23,33 @@ def compute_loss_autocast(params, inputs, targets): print("1. Without autocast") grads = ch.func.grad(compute_loss)(weights, inputs, targets) - print(f'grads are {grads}') + print(f"grads are {grads}") print(f"grads dtype: {grads['weight'].dtype}") - print('='*50) + print("=" * 50) inputs = inputs.half() targets = targets.half() - print('2. With autocast for forward pass') + print("2. With autocast for forward pass") grads = ch.func.grad(compute_loss_autocast)(weights, inputs, targets) - print(f'grads are {grads}') + print(f"grads are {grads}") print(f"grads dtype: {grads['weight'].dtype}") - print('='*50) + print("=" * 50) - print('3. With autocast for forward pass and backward pass') + print("3. 
With autocast for forward pass and backward pass") with autocast(device_type="cuda", dtype=ch.float16): grads = ch.func.grad(compute_loss)(weights, inputs, targets) - print(f'inside grads are {grads}') + print(f"inside grads are {grads}") print(f"inside grads dtype: {grads['weight'].dtype}") - print('exiting autocast') - print(f'grads are {grads}') + print("exiting autocast") + print(f"grads are {grads}") print(f"grads dtype: {grads['weight'].dtype}") - print('='*50) + print("=" * 50) - print('4. .half() the model') + print("4. .half() the model") model = model.half() grads = ch.func.grad(compute_loss)(weights, inputs, targets) - print(f'grads are {grads}') + print(f"grads are {grads}") print(f"grads dtype: {grads['weight'].dtype}") """ diff --git a/tests/memory_profiling.py b/tests/memory_profiling.py index 2dc2c4a..6283c03 100644 --- a/tests/memory_profiling.py +++ b/tests/memory_profiling.py @@ -11,35 +11,39 @@ ch = torch -def test_cifar_acc(serialize=False, dtype=ch.float32, batch_size=100, tmp_path='/tmp/trak_results/'): - device = 'cuda:0' +def test_cifar_acc( + serialize=False, dtype=ch.float32, batch_size=100, tmp_path="/tmp/trak_results/" +): + device = "cuda:0" model = construct_rn9().to(memory_format=ch.channels_last).to(device) model = model.eval() - BETONS_PATH = Path(tmp_path).joinpath('cifar_betons') + BETONS_PATH = Path(tmp_path).joinpath("cifar_betons") BETONS = download_cifar_betons(BETONS_PATH) - loader_train = get_dataloader(BETONS, batch_size=batch_size, split='train') - loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val') + loader_train = get_dataloader(BETONS, batch_size=batch_size, split="train") + loader_val = get_dataloader(BETONS, batch_size=batch_size, split="val") - CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts') + CKPT_PATH = Path(tmp_path).joinpath("cifar_ckpts") ckpt_files = download_cifar_checkpoints(CKPT_PATH) - ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files] + ckpts = [ch.load(ckpt, map_location="cpu") for ckpt in ckpt_files] reporter = MemReporter() - traker = TRAKer(model=model, - task='image_classification', - proj_dim=1024, - train_set_size=10_000, - save_dir=tmp_path, - logging_level=logging.DEBUG, - device=device) + traker = TRAKer( + model=model, + task="image_classification", + proj_dim=1024, + train_set_size=10_000, + save_dir=tmp_path, + logging_level=logging.DEBUG, + device=device, + ) for model_id, ckpt in enumerate(ckpts): traker.load_checkpoint(checkpoint=ckpt, model_id=model_id) - for batch in tqdm(loader_train, desc='Computing TRAK embeddings...'): + for batch in tqdm(loader_train, desc="Computing TRAK embeddings..."): traker.featurize(batch=batch, num_samples=len(batch[0])) reporter.report() @@ -47,20 +51,24 @@ def test_cifar_acc(serialize=False, dtype=ch.float32, batch_size=100, tmp_path=' if serialize: del traker - traker = TRAKer(model=model, - task='image_classification', - proj_dim=1024, - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) + traker = TRAKer( + model=model, + task="image_classification", + proj_dim=1024, + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) for model_id, ckpt in enumerate(ckpts): - traker.start_scoring_checkpoint('test_experiment', ckpt, model_id, num_targets=2_000) - for batch in tqdm(loader_val, desc='Scoring...'): + traker.start_scoring_checkpoint( + "test_experiment", ckpt, model_id, num_targets=2_000 + ) + for batch in tqdm(loader_val, desc="Scoring..."): 
traker.score(batch=batch, num_samples=len(batch[0])) - traker.finalize_scores('test_experiment') + traker.finalize_scores("test_experiment") with LineProfiler(test_cifar_acc, TRAKer.featurize, TRAKer.load_checkpoint) as prof: diff --git a/tests/test_cifar10_accuracy.py b/tests/test_cifar10_accuracy.py index 9b164e5..0b0e4ee 100644 --- a/tests/test_cifar10_accuracy.py +++ b/tests/test_cifar10_accuracy.py @@ -15,30 +15,38 @@ ch = torch -def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True): +def get_dataloader( + batch_size=256, num_workers=8, split="train", shuffle=False, augment=True +): if augment: transforms = torchvision.transforms.Compose( - [torchvision.transforms.RandomHorizontalFlip(), - torchvision.transforms.RandomAffine(0), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.201))]) + [ + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.RandomAffine(0), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201) + ), + ] + ) else: - transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.201))]) - - is_train = (split == 'train') - dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/', - download=True, - train=is_train, - transform=transforms) - - loader = torch.utils.data.DataLoader(dataset=dataset, - shuffle=shuffle, - batch_size=batch_size, - num_workers=num_workers) + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201) + ), + ] + ) + + is_train = split == "train" + dataset = torchvision.datasets.CIFAR10( + root="/tmp/cifar/", download=True, train=is_train, transform=transforms + ) + + loader = torch.utils.data.DataLoader( + dataset=dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers + ) return loader @@ -46,46 +54,57 @@ def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, def get_projector(use_cuda_projector, dtype): if use_cuda_projector: return None - return BasicProjector(grad_dim=2274880, proj_dim=2048, - seed=0, proj_type='normal', block_size=400, - dtype=dtype, device='cuda:0') + return BasicProjector( + grad_dim=2274880, + proj_dim=2048, + seed=0, + proj_type="normal", + block_size=400, + dtype=dtype, + device="cuda:0", + ) # reduce the number of tests for CIFAR-10 -PARAM = list(product([False], # serialize - [True], # basic / cuda projector - [ch.float16], # projection dtype - [128], # batch size - )) +PARAM = list( + product( + [False], # serialize + [True], # basic / cuda projector + [ch.float16], # projection dtype + [128], # batch size + ) +) @pytest.mark.parametrize("serialize, use_cuda_projector, dtype, batch_size", PARAM) @pytest.mark.cuda def test_cifar_acc(serialize, use_cuda_projector, dtype, batch_size, tmp_path): - device = 'cuda:0' + device = "cuda:0" projector = get_projector(use_cuda_projector, dtype) model = construct_rn9(10).to(memory_format=ch.channels_last).to(device) model = model.eval() - loader_train = get_dataloader(batch_size=batch_size, split='train', augment=False) - loader_val = get_dataloader(batch_size=batch_size, split='val', augment=False) + loader_train = get_dataloader(batch_size=batch_size, split="train", augment=False) + loader_val = 
get_dataloader(batch_size=batch_size, split="val", augment=False) - CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts') + CKPT_PATH = Path(tmp_path).joinpath("cifar_ckpts") ckpt_files = download_cifar_checkpoints(CKPT_PATH) - ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files] + ckpts = [ch.load(ckpt, map_location="cpu") for ckpt in ckpt_files] - traker = TRAKer(model=model, - task='image_classification', - projector=projector, - train_set_size=50_000, - save_dir=tmp_path, - logging_level=logging.DEBUG, - device=device) + traker = TRAKer( + model=model, + task="image_classification", + projector=projector, + train_set_size=50_000, + save_dir=tmp_path, + logging_level=logging.DEBUG, + device=device, + ) for model_id, ckpt in enumerate(ckpts): traker.load_checkpoint(checkpoint=ckpt, model_id=model_id) - for batch in tqdm(loader_train, desc='Computing TRAK embeddings...'): + for batch in tqdm(loader_train, desc="Computing TRAK embeddings..."): batch = [x.cuda() for x in batch] traker.featurize(batch=batch, num_samples=len(batch[0])) @@ -93,28 +112,34 @@ def test_cifar_acc(serialize, use_cuda_projector, dtype, batch_size, tmp_path): if serialize: del traker - traker = TRAKer(model=model, - task='image_classification', - projector=projector, - train_set_size=50_000, - save_dir=tmp_path, - logging_level=logging.DEBUG, - device=device) + traker = TRAKer( + model=model, + task="image_classification", + projector=projector, + train_set_size=50_000, + save_dir=tmp_path, + logging_level=logging.DEBUG, + device=device, + ) for model_id, ckpt in enumerate(ckpts): - traker.start_scoring_checkpoint(exp_name='test_experiment', - checkpoint=ckpt, - model_id=model_id, - num_targets=10_000) - for batch in tqdm(loader_val, desc='Scoring...'): + traker.start_scoring_checkpoint( + exp_name="test_experiment", + checkpoint=ckpt, + model_id=model_id, + num_targets=10_000, + ) + for batch in tqdm(loader_val, desc="Scoring..."): batch = [x.cuda() for x in batch] traker.score(batch=batch, num_samples=len(batch[0])) print(traker.saver.experiments) - scores = traker.finalize_scores(exp_name='test_experiment') + scores = traker.finalize_scores(exp_name="test_experiment") print(scores) print(scores.shape) avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path) - assert avg_corr > 0.05, 'correlation with the above 3 CIFAR-10 checkpoints should be >= 0.05' + assert ( + avg_corr > 0.05 + ), "correlation with the above 3 CIFAR-10 checkpoints should be >= 0.05" diff --git a/tests/test_cifar2_accuracy.py b/tests/test_cifar2_accuracy.py index 2c38e85..3dbcc91 100644 --- a/tests/test_cifar2_accuracy.py +++ b/tests/test_cifar2_accuracy.py @@ -15,75 +15,87 @@ def get_projector(use_cuda_projector, dtype): if use_cuda_projector: return None - return BasicProjector(grad_dim=2273856, proj_dim=1024, - seed=0, proj_type='rademacher', - dtype=dtype, device='cuda:0') - - -PARAM = list(product([False, True], # serialize - [False, True], # basic / cuda projector - [ch.float16, ch.float32], # projection dtype - [100, 32], # batch size - )) + return BasicProjector( + grad_dim=2273856, + proj_dim=1024, + seed=0, + proj_type="rademacher", + dtype=dtype, + device="cuda:0", + ) + + +PARAM = list( + product( + [False, True], # serialize + [False, True], # basic / cuda projector + [ch.float16, ch.float32], # projection dtype + [100, 32], # batch size + ) +) @pytest.mark.parametrize("serialize, use_cuda_projector, dtype, batch_size", PARAM) @pytest.mark.cuda def test_cifar_acc(serialize, use_cuda_projector, dtype, batch_size, 
tmp_path): - device = 'cuda:0' - exp_name = 'test_experimet' + device = "cuda:0" + exp_name = "test_experiment" projector = get_projector(use_cuda_projector, dtype) model = construct_rn9().to(memory_format=ch.channels_last).to(device) model = model.eval() - BETONS_PATH = Path(tmp_path).joinpath('cifar_betons') + BETONS_PATH = Path(tmp_path).joinpath("cifar_betons") BETONS = download_cifar_betons(BETONS_PATH) - loader_train = get_dataloader(BETONS, batch_size=batch_size, split='train') - loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val') + loader_train = get_dataloader(BETONS, batch_size=batch_size, split="train") + loader_val = get_dataloader(BETONS, batch_size=batch_size, split="val") - CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts') - ckpt_files = download_cifar_checkpoints(CKPT_PATH, 'cifar2') - ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files] + CKPT_PATH = Path(tmp_path).joinpath("cifar_ckpts") + ckpt_files = download_cifar_checkpoints(CKPT_PATH, "cifar2") + ckpts = [ch.load(ckpt, map_location="cpu") for ckpt in ckpt_files] - use_half_precision = (dtype == ch.float16) + use_half_precision = dtype == ch.float16 - traker = TRAKer(model=model, - task='image_classification', - projector=projector, - proj_dim=1024, - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG, - use_half_precision=use_half_precision) + traker = TRAKer( + model=model, + task="image_classification", + projector=projector, + proj_dim=1024, + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + use_half_precision=use_half_precision, + ) for model_id, ckpt in enumerate(ckpts): traker.load_checkpoint(checkpoint=ckpt, model_id=model_id) - for batch in tqdm(loader_train, desc='Computing TRAK embeddings...'): + for batch in tqdm(loader_train, desc="Computing TRAK embeddings..."): traker.featurize(batch=batch, num_samples=len(batch[0])) traker.finalize_features() if serialize: del traker - traker = TRAKer(model=model, - task='image_classification', - projector=projector, - proj_dim=1024, - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG, - use_half_precision=use_half_precision) + traker = TRAKer( + model=model, + task="image_classification", + projector=projector, + proj_dim=1024, + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + use_half_precision=use_half_precision, + ) for model_id, ckpt in enumerate(ckpts): traker.start_scoring_checkpoint(exp_name, ckpt, model_id, num_targets=2_000) - for batch in tqdm(loader_val, desc='Scoring...'): + for batch in tqdm(loader_val, desc="Scoring..."): traker.score(batch=batch, num_samples=len(batch[0])) scores = traker.finalize_scores(exp_name) - avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds='cifar2') - assert avg_corr > 0.062, 'correlation with 3 CIFAR-2 models should be >= 0.062' + avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds="cifar2") + assert avg_corr > 0.062, "correlation with 3 CIFAR-2 models should be >= 0.062" diff --git a/tests/test_class.py b/tests/test_class.py index 5e36ab2..b4bf2e3 100644 --- a/tests/test_class.py +++ b/tests/test_class.py @@ -9,58 +9,64 @@ @pytest.fixture def cpu_proj(): - projector = BasicProjector(grad_dim=11689512, - proj_dim=20, - seed=0, - proj_type='rademacher', - device='cpu') + projector = BasicProjector( + grad_dim=11689512, proj_dim=20, seed=0, proj_type="rademacher", device="cpu" + ) return projector 
def test_class_init_cpu(tmp_path, cpu_proj): model = resnet18() - TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - projector=cpu_proj, - train_set_size=20, - logging_level=logging.DEBUG, - device='cpu') + TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + projector=cpu_proj, + train_set_size=20, + logging_level=logging.DEBUG, + device="cpu", + ) def test_class_init(tmp_path, cpu_proj): model = resnet18() - TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - projector=cpu_proj, - train_set_size=20, - logging_level=logging.DEBUG, - device='cuda:0') + TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + projector=cpu_proj, + train_set_size=20, + logging_level=logging.DEBUG, + device="cuda:0", + ) def test_load_ckpt(tmp_path, cpu_proj): model = resnet18() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - projector=cpu_proj, - train_set_size=20, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + projector=cpu_proj, + train_set_size=20, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) def test_load_ckpt_repeat(tmp_path, cpu_proj): model = resnet18() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - projector=cpu_proj, - train_set_size=20, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + projector=cpu_proj, + train_set_size=20, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.load_checkpoint(ckpt, model_id=1) @@ -71,12 +77,14 @@ def test_featurize(tmp_path): model = resnet18().cuda().eval() N = 32 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.featurize(batch, num_samples=N) @@ -87,13 +95,15 @@ def test_max_batch_size(tmp_path): model = resnet18().cuda().eval() N = 32 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - proj_max_batch_size=16, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + proj_max_batch_size=16, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.featurize(batch, num_samples=N) @@ -103,13 +113,15 @@ def test_class_featurize_cpu(tmp_path, cpu_proj): model = resnet18() N = 5 batch = ch.randn(N, 3, 32, 32), ch.randint(low=0, high=10, size=(N,)) - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - projector=cpu_proj, - train_set_size=N, - logging_level=logging.DEBUG, - device='cpu') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + projector=cpu_proj, + train_set_size=N, + 
logging_level=logging.DEBUG, + device="cpu", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) @@ -120,13 +132,15 @@ def test_class_featurize_noop(tmp_path): model = resnet18() N = 5 batch = ch.randn(N, 3, 32, 32), ch.randint(low=0, high=10, size=(N,)) - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - projector=NoOpProjector(), - train_set_size=N, - logging_level=logging.DEBUG, - device='cpu') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + projector=NoOpProjector(device="cpu"), + train_set_size=N, + logging_level=logging.DEBUG, + device="cpu", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) @@ -138,14 +152,18 @@ def test_forgot_loading_ckpt(tmp_path): model = resnet18().cuda().eval() N = 5 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') - with pytest.raises(AssertionError, - match='Load a checkpoint using traker.load_checkpoint before featurizing'): + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) + with pytest.raises( + AssertionError, + match="Load a checkpoint using traker.load_checkpoint before featurizing", + ): traker.featurize(batch, num_samples=N) @@ -154,12 +172,14 @@ def test_finalize_features(tmp_path): model = resnet18().cuda().eval() N = 5 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.featurize(batch, num_samples=N) @@ -171,12 +191,14 @@ def test_finalize_features_multiple_ftr(tmp_path): model = resnet18().cuda().eval() N = 10 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.featurize([x[:3] for x in batch], num_samples=3) @@ -190,12 +212,14 @@ def test_finalize_features_multiple_ftr_and_id(tmp_path): model = resnet18().cuda().eval() N = 10 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() for model_id in range(2): traker.load_checkpoint(ckpt, model_id=model_id) @@ -210,17 +234,19 @@ def test_score(tmp_path): model = resnet18().cuda().eval() N = 5 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - 
task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.featurize(batch, num_samples=N) traker.finalize_features() - traker.start_scoring_checkpoint('test_experiment', ckpt, 0, num_targets=N) + traker.start_scoring_checkpoint("test_experiment", ckpt, 0, num_targets=N) traker.score(batch, num_samples=N) @@ -229,20 +255,22 @@ def test_score_finalize(tmp_path): model = resnet18().cuda().eval() N = 5 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.featurize(batch, num_samples=N) traker.finalize_features() - traker.start_scoring_checkpoint('test_experiment', ckpt, 0, num_targets=N) + traker.start_scoring_checkpoint("test_experiment", ckpt, 0, num_targets=N) traker.score(batch, num_samples=N) - traker.finalize_scores(exp_name='test_experiment') + traker.finalize_scores(exp_name="test_experiment") @pytest.mark.cuda @@ -250,12 +278,14 @@ def test_score_finalize_some_model_ids(tmp_path): model = resnet18().cuda().eval() N = 5 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.featurize(batch, num_samples=N) @@ -264,9 +294,9 @@ def test_score_finalize_some_model_ids(tmp_path): traker.featurize(batch, num_samples=N) traker.finalize_features() - traker.start_scoring_checkpoint('test_experiment', ckpt, 0, num_targets=N) + traker.start_scoring_checkpoint("test_experiment", ckpt, 0, num_targets=N) traker.score(batch, num_samples=N) - traker.finalize_scores(exp_name='test_experiment', model_ids=[0]) + traker.finalize_scores(exp_name="test_experiment", model_ids=[0]) @pytest.mark.cuda @@ -274,12 +304,14 @@ def test_score_finalize_split(tmp_path): model = resnet18().cuda().eval() N = 5 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.featurize(batch, num_samples=N) @@ -288,19 +320,21 @@ def test_score_finalize_split(tmp_path): traker.featurize(batch, num_samples=N) traker.finalize_features() - traker.start_scoring_checkpoint('test_experiment', ckpt, 0, num_targets=N) + traker.start_scoring_checkpoint("test_experiment", ckpt, 0, num_targets=N) traker.score(batch, 
num_samples=N) - traker.start_scoring_checkpoint('test_experiment', ckpt, 1, num_targets=N) + traker.start_scoring_checkpoint("test_experiment", ckpt, 1, num_targets=N) traker.score(batch, num_samples=N) - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0') - traker.finalize_scores(exp_name='test_experiment') + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + ) + traker.finalize_scores(exp_name="test_experiment") @pytest.mark.cuda @@ -308,29 +342,81 @@ def test_score_finalize_full_precision(tmp_path): model = resnet18().cuda().eval() N = 5 batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() - traker = TRAKer(model=model, - task='image_classification', - save_dir=tmp_path, - train_set_size=N, - logging_level=logging.DEBUG, - device='cuda:0', - use_half_precision=False) + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + use_half_precision=False, + ) ckpt = model.state_dict() traker.load_checkpoint(ckpt, model_id=0) traker.featurize(batch, num_samples=N) traker.finalize_features() - traker.start_scoring_checkpoint('test_experiment', ckpt, 0, num_targets=N) + traker.start_scoring_checkpoint("test_experiment", ckpt, 0, num_targets=N) traker.score(batch, num_samples=N) - traker.finalize_scores(exp_name='test_experiment') + traker.finalize_scores(exp_name="test_experiment") def test_custom_model_output(tmp_path, cpu_proj): model = resnet18() - TRAKer(model=model, - task=ImageClassificationModelOutput(), - save_dir=tmp_path, - projector=cpu_proj, - train_set_size=20, - logging_level=logging.DEBUG, - device='cpu') + TRAKer( + model=model, + task=ImageClassificationModelOutput(), + save_dir=tmp_path, + projector=cpu_proj, + train_set_size=20, + logging_level=logging.DEBUG, + device="cpu", + ) + + +def test_grad_wrt_last_layer(tmp_path): + model = resnet18().eval() + N = 5 + batch = ch.randn(N, 3, 32, 32), ch.randint(low=0, high=10, size=(N,)) + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cpu", + use_half_precision=False, + grad_wrt=["fc.weight", "fc.bias"], + ) + ckpt = model.state_dict() + traker.load_checkpoint(ckpt, model_id=0) + traker.featurize(batch, num_samples=N) + traker.finalize_features() + + traker.start_scoring_checkpoint("test_experiment", ckpt, 0, num_targets=N) + traker.score(batch, num_samples=N) + traker.finalize_scores(exp_name="test_experiment") + + +@pytest.mark.cuda +def test_grad_wrt_last_layer_cuda(tmp_path): + model = resnet18().cuda().eval() + N = 5 + batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() + traker = TRAKer( + model=model, + task="image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cuda:0", + grad_wrt=["fc.weight", "fc.bias"], + ) + ckpt = model.state_dict() + traker.load_checkpoint(ckpt, model_id=0) + traker.featurize(batch, num_samples=N) + traker.finalize_features() + + traker.start_scoring_checkpoint("test_experiment", ckpt, 0, num_targets=N) + traker.score(batch, num_samples=N) + traker.finalize_scores(exp_name="test_experiment") diff --git a/tests/test_integration_cifar.py b/tests/test_integration_cifar.py index cad31c1..0a8a767 
100644 --- a/tests/test_integration_cifar.py +++ b/tests/test_integration_cifar.py @@ -8,44 +8,50 @@ from trak.projectors import BasicProjector -def test_cifar10(tmp_path, device='cpu'): - model = models.resnet18(weights='DEFAULT') +def test_cifar10(tmp_path, device="cpu"): + model = models.resnet18(weights="DEFAULT") model.to(device) model.eval() - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) - ds_train = datasets.CIFAR10(root='/tmp', download=True, train=True, transform=transform) + ds_train = datasets.CIFAR10( + root="/tmp", download=True, train=True, transform=transform + ) loader_train = DataLoader(ds_train, batch_size=10, shuffle=False) - if device == 'cpu': + if device == "cpu": # the default CudaProjector does not work on cpu - projector = BasicProjector(grad_dim=sum(x.numel() for x in model.parameters()), - proj_dim=20, - seed=0, - proj_type='rademacher', - device=device) + projector = BasicProjector( + grad_dim=sum(x.numel() for x in model.parameters()), + proj_dim=20, + seed=0, + proj_type="rademacher", + device=device, + ) else: projector = None - traker = TRAKer(model=model, - task='image_classification', - train_set_size=len(ds_train), - projector=projector, - save_dir=tmp_path, - logging_level=logging.DEBUG, - device=device) + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=len(ds_train), + projector=projector, + save_dir=tmp_path, + logging_level=logging.DEBUG, + device=device, + ) ckpts = [model.state_dict(), model.state_dict()] for model_id, ckpt in enumerate(ckpts): traker.load_checkpoint(ckpt, model_id=model_id) - for batch in tqdm(loader_train, desc='Computing TRAK embeddings...'): + for batch in tqdm(loader_train, desc="Computing TRAK embeddings..."): batch = [x.to(device) for x in batch] - traker.featurize(batch=batch, - num_samples=loader_train.batch_size) + traker.featurize(batch=batch, num_samples=loader_train.batch_size) break # a CPU pass takes too long lol @pytest.mark.cuda def test_cifar10_cuda(tmp_path): - test_cifar10(tmp_path, device='cuda:0') + test_cifar10(tmp_path, device="cuda:0") diff --git a/tests/test_integration_clip.py b/tests/test_integration_clip.py index 1556e0d..3375031 100644 --- a/tests/test_integration_clip.py +++ b/tests/test_integration_clip.py @@ -7,29 +7,37 @@ @pytest.mark.cuda -def test_mscoco(tmp_path, device='cuda:0'): - model, _, preprocess = open_clip.create_model_and_transforms('RN50') +def test_mscoco(tmp_path, device="cuda:0"): + model, _, preprocess = open_clip.create_model_and_transforms("RN50") model = model.to(device) model.eval() - tokenizer = open_clip.get_tokenizer('RN50') - - ds_train = datasets.CocoCaptions(root='/path/to/coco2014/images/train2014', - annFile='/path/to/coco2014/annotations/annotations/captions_train2014.json' - ) - - traker = TRAKer(model=model, - task='clip', - save_dir=tmp_path, - train_set_size=len(ds_train), - device=device, - proj_dim=512, - logging_level=logging.DEBUG - ) - - traker.task.get_embeddings(model, ds_train, batch_size=1, size=600, embedding_dim=1024, - preprocess_fn_img=lambda x: preprocess(x).to(device).unsqueeze(0), - preprocess_fn_txt=lambda x: tokenizer(x[0]).to(device)) + tokenizer = open_clip.get_tokenizer("RN50") + + ds_train = datasets.CocoCaptions( + root="/path/to/coco2014/images/train2014", + 
annFile="/path/to/coco2014/annotations/annotations/captions_train2014.json", + ) + + traker = TRAKer( + model=model, + task="clip", + save_dir=tmp_path, + train_set_size=len(ds_train), + device=device, + proj_dim=512, + logging_level=logging.DEBUG, + ) + + traker.task.get_embeddings( + model, + ds_train, + batch_size=1, + size=600, + embedding_dim=1024, + preprocess_fn_img=lambda x: preprocess(x).to(device).unsqueeze(0), + preprocess_fn_txt=lambda x: tokenizer(x[0]).to(device), + ) traker.load_checkpoint(model.state_dict(), model_id=0) for bind, (img, captions) in enumerate(tqdm(ds_train)): diff --git a/tests/test_integration_qnli.py b/tests/test_integration_qnli.py index 58364a3..1c48293 100644 --- a/tests/test_integration_qnli.py +++ b/tests/test_integration_qnli.py @@ -36,49 +36,52 @@ class SequenceClassificationModel(nn.Module): """ Wrapper for HuggingFace sequence classification models. """ + def __init__(self): super().__init__() self.config = AutoConfig.from_pretrained( - 'bert-base-cased', + "bert-base-cased", num_labels=2, - finetuning_task='qnli', + finetuning_task="qnli", cache_dir=None, - revision='main', + revision="main", use_auth_token=None, ) self.model = AutoModelForSequenceClassification.from_pretrained( - 'bert-base-cased', + "bert-base-cased", config=self.config, cache_dir=None, - revision='main', + revision="main", use_auth_token=None, - ignore_mismatched_sizes=False + ignore_mismatched_sizes=False, ) self.model.eval().cuda() def forward(self, input_ids, token_type_ids, attention_mask): - return self.model(input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask).logits + return self.model( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + ).logits def get_dataset(split, inds=None): raw_datasets = load_dataset( - "glue", - 'qnli', - cache_dir=None, - use_auth_token=None, - ) - sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS['qnli'] + "glue", + "qnli", + cache_dir=None, + use_auth_token=None, + ) + sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS["qnli"] tokenizer = AutoTokenizer.from_pretrained( - 'bert-base-cased', + "bert-base-cased", cache_dir=None, use_fast=True, - revision='main', - use_auth_token=False + revision="main", + use_auth_token=False, ) padding = "max_length" @@ -87,9 +90,13 @@ def get_dataset(split, inds=None): def preprocess_function(examples): # Tokenize the texts args = ( - (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) + (examples[sentence1_key],) + if sentence2_key is None + else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer( + *args, padding=padding, max_length=max_seq_length, truncation=True ) - result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) return result @@ -100,7 +107,7 @@ def preprocess_function(examples): desc="Running tokenizer on dataset", ) - if split == 'train': + if split == "train": train_dataset = raw_datasets["train"] ds = train_dataset else: @@ -110,42 +117,52 @@ def preprocess_function(examples): def init_loaders(batch_size=10): - ds_train = get_dataset('train') + ds_train = get_dataset("train") ds_train = ds_train.select(range(TRAIN_SET_SIZE)) - ds_val = get_dataset('val') + ds_val = get_dataset("val") ds_val = ds_val.select(range(VAL_SET_SIZE)) - return DataLoader(ds_train, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator), \ - DataLoader(ds_val, batch_size=batch_size, shuffle=False, 
collate_fn=default_data_collator) + return DataLoader( + ds_train, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator + ), DataLoader( + ds_val, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator + ) def process_batch(batch): - return batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['labels'] + return ( + batch["input_ids"], + batch["token_type_ids"], + batch["attention_mask"], + batch["labels"], + ) # model too large to test on CPU @pytest.mark.cuda -def test_qnli(tmp_path, device='cuda'): +def test_qnli(tmp_path, device="cuda"): loader_train, loader_val = init_loaders() # no need to load model from checkpoint, just testing featurization and scoring model = SequenceClassificationModel() - logger = logging.getLogger('QNLI') + logger = logging.getLogger("QNLI") logger.setLevel(logging.DEBUG) - logger.info(f'Initializing TRAKer with device {device}') - - traker = TRAKer(model=model, - task='text_classification', - train_set_size=TRAIN_SET_SIZE, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG, - proj_dim=512) + logger.info(f"Initializing TRAKer with device {device}") + + traker = TRAKer( + model=model, + task="text_classification", + train_set_size=TRAIN_SET_SIZE, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + proj_dim=512, + ) - logger.info('Loading checkpoint') + logger.info("Loading checkpoint") traker.load_checkpoint(model.state_dict(), model_id=0) - logger.info('Loaded checkpoint') - for batch in tqdm(loader_train, desc='Featurizing..'): + logger.info("Loaded checkpoint") + for batch in tqdm(loader_train, desc="Featurizing.."): # process batch into compatible form for TRAKer TextClassificationModelOutput batch = process_batch(batch) batch = [x.to(device) for x in batch] @@ -153,13 +170,15 @@ def test_qnli(tmp_path, device='cuda'): traker.finalize_features() - traker.start_scoring_checkpoint(exp_name='qnli', - checkpoint=model.state_dict(), - model_id=0, - num_targets=VAL_SET_SIZE) - for batch in tqdm(loader_val, desc='Scoring..'): + traker.start_scoring_checkpoint( + exp_name="qnli", + checkpoint=model.state_dict(), + model_id=0, + num_targets=VAL_SET_SIZE, + ) + for batch in tqdm(loader_val, desc="Scoring.."): batch = process_batch(batch) batch = [x.to(device) for x in batch] traker.score(batch=batch, num_samples=batch[0].shape[0]) - traker.finalize_scores(exp_name='qnli') + traker.finalize_scores(exp_name="qnli") diff --git a/tests/test_jl.py b/tests/test_jl.py index 5e3285e..0939427 100644 --- a/tests/test_jl.py +++ b/tests/test_jl.py @@ -2,114 +2,202 @@ import math from itertools import product import numpy as np -import torch as ch +import torch from torch import testing -from trak.projectors import CudaProjector, ProjectionType +from trak.projectors import CudaProjector, ProjectionType, ChunkedCudaProjector + +ch = torch + + +def get_max_chunk_size( + batch_size: int, +): + max_chunk_size = np.iinfo(np.uint32).max // batch_size + return max_chunk_size + + +def make_input( + input_shape, max_chunk_size, device="cuda", dtype=torch.float32, g_tensor=None +): + if g_tensor is None: + g = testing.make_tensor(*input_shape, device=device, dtype=dtype) + else: + g = g_tensor + _, num_params = input_shape + num_chunks = np.ceil(num_params / max_chunk_size).astype("int32") + g_chunks = ch.chunk(g, num_chunks, dim=1) + result = {} + for i, x in enumerate(g_chunks): + result[i] = x + print(f"Input param group {i} shape: {x.shape}") + + return result + + BasicProjector = CudaProjector 
MAX_BATCH_SIZE = 32 -PARAM = list(product([0, 1, 10**8], # seed - [ProjectionType.normal, ProjectionType.rademacher], # proj type - [ch.float16, ch.float32], # dtype - [ - (1, 25), - (8, 10_000), - (16, 10_002), - (9, 10_002), - (16, 10_001), - (45, 1049), - (1, int(1e9)), - ], # input shape - [4096, 1024], # proj dim - )) +PARAM = list( + product( + [0, 1, 10**8], # seed + [ProjectionType.normal, ProjectionType.rademacher], # proj type + [ch.float16, ch.float32], # dtype + [ + (1, 25), + (8, 10_000), + (16, 10_002), + (9, 10_002), + (16, 10_001), + (45, 1049), + (1, int(1e9)), + ], # input shape + [4096, 1024], # proj dim + ) +) + +PARAM = list( + product( + [123], # seed + [ProjectionType.rademacher], # proj type + [ch.float32], # dtype + [ + # tests for MAXINT32 overflow + (8, 180645096), # pass: np.prod(shape) < np.iinfo(np.int32).max + (31, 180645096), # fail: np.prod(shape) > np.iinfo(np.int32).max + (32, 180645096), # fail: np.prod(shape) > np.iinfo(np.int32).max + (2, 780645096), # fail: np.prod(shape) > np.iinfo(np.int32).max + ], # input shape + [15_360], # proj dim + ) +) @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_seed_consistency(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_seed_consistency( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that re-running the same projection with the same seed leads to the same result. """ - g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype) - proj = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + proj = BasicProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) + batch_size = input_shape[0] + max_chunk_size = get_max_chunk_size(batch_size) + g = make_input(input_shape, max_chunk_size, "cuda:0", dtype) result = proj.project(g, model_id=0) result_again = proj.project(g, model_id=0) testing.assert_close(result, result_again, equal_nan=True) + del g + torch.cuda.empty_cache() + @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_seed_consistency_2(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_seed_consistency_2( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that re-initializing the class and re-running the same projection with the same seed leads to the same result. 
""" - g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype) - proj = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + batch_size = input_shape[0] + max_chunk_size = get_max_chunk_size(batch_size) + g = make_input(input_shape, max_chunk_size, "cuda:0", dtype) + + proj = BasicProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) result = proj.project(g, model_id=0) - proj_again = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + proj_again = BasicProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) result_again = proj_again.project(g, model_id=0) testing.assert_close(result, result_again, equal_nan=True) + del g + torch.cuda.empty_cache() + @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_norm_preservation(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_norm_preservation( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that norms of differences are approximately preserved. """ - g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype) - proj = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + batch_size = input_shape[0] + max_chunk_size = get_max_chunk_size(batch_size) + g = make_input(input_shape, max_chunk_size, "cuda:0", dtype) + + rng = np.random.default_rng(seed) + seeds = rng.integers( + low=0, + high=500, + size=len(g), + ) + + param_chunk_sizes = [v.size(1) for v in g.values()] + projector_per_chunk = [ + BasicProjector( + grad_dim=chunk_size, + proj_dim=proj_dim, + seed=seeds[i], + proj_type=proj_type, + max_batch_size=MAX_BATCH_SIZE, + dtype=dtype, + device="cuda:0", + ) + for i, chunk_size in enumerate(param_chunk_sizes) + ] + proj = ChunkedCudaProjector( + projector_per_chunk, + max_chunk_size, + param_chunk_sizes, + batch_size, + "cuda:0", + dtype, + ) p = proj.project(g, model_id=0) @@ -118,6 +206,10 @@ def test_norm_preservation(seed, num_trials = 100 num_successes = 0 + # flatten + g = ch.cat([v for v in g.values()], dim=1) + print(f"Flattened input shape: {g.shape}") + for _ in range(num_trials): i, j = np.random.choice(range(g.shape[0]), size=2) n = (g[i] - g[j]).norm() @@ -129,31 +221,39 @@ def test_norm_preservation(seed, # 35 is an arbitrary constant # if NaN, just give up and count as success if math.isinf(res): - print('aaaaaa') + print("aaaaaa") num_successes += int(res <= 35 * eps * n) assert num_successes >= num_trials * (1 - 3 * delta) # leeway with 2 * + del g + torch.cuda.empty_cache() + @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_prod_preservation(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_prod_preservation( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that dot products are approximately preserved. 
""" - g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype) - proj = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + batch_size = input_shape[0] + max_chunk_size = get_max_chunk_size(batch_size) + g = make_input(input_shape, max_chunk_size, "cuda:0", dtype) + + proj = BasicProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) # check that things break with a garbage matrix # (making sure the constant 15 is reasonable) @@ -166,10 +266,14 @@ def test_prod_preservation(seed, num_trials = 100 num_successes = 0 + # flatten + g = ch.cat([v for v in g.values()], dim=1) + print(f"Flattened input shape: {g.shape}") + for _ in range(num_trials): i, j = np.random.choice(range(g.shape[0]), size=2) - n = (g[i] @ g[j]) - pn = ((p[i] / np.sqrt(input_shape[-1])) @ (p[j] / input_shape[-1])) + n = g[i] @ g[j] + pn = (p[i] / np.sqrt(input_shape[-1])) @ (p[j] / input_shape[-1]) res = (n.abs() - pn.abs()).cpu().abs().item() t = (50 * np.sqrt(proj_dim) * eps * n).abs().item() # if NaN, just give up and count as success @@ -177,144 +281,210 @@ def test_prod_preservation(seed, assert num_successes >= num_trials * (1 - 2 * delta) + del g + torch.cuda.empty_cache() + @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_single_nonzero_feature(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_single_nonzero_feature( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that output takes into account every feature. """ - g = ch.zeros(*input_shape, device='cuda:0', dtype=dtype) + + batch_size = input_shape[0] + max_chunk_size = get_max_chunk_size(batch_size) + g = make_input(input_shape, max_chunk_size, "cuda:0", dtype) + for k in g.keys(): + g[k] = ch.zeros_like(g[k]) + for ind in range(input_shape[0]): - coord = np.random.choice(range(input_shape[1])) + param_group = np.random.choice(range(len(g.keys()))) + coord = np.random.choice(range(g[param_group].size(1))) val = ch.randn(1) - g[ind, coord] = val.item() - - proj = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + g[param_group][ind, coord] = val.item() + + proj = BasicProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) p = proj.project(g, model_id=0) assert (~ch.isclose(p, ch.zeros_like(p))).all().item() @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_first_nonzero_feature(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_first_nonzero_feature( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that output takes into account first features. """ - g = ch.zeros(*input_shape, device='cuda:0', dtype=dtype) - g[:, 0] = 1. 
- - proj = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + g = ch.zeros(*input_shape, device="cuda:0", dtype=dtype) + g[:, 0] = 1.0 + + batch_size = input_shape[0] + max_chunk_size = get_max_chunk_size(batch_size) + g = make_input(input_shape, max_chunk_size, g_tensor=g) + print(g[0]) + + proj = BasicProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) p = proj.project(g, model_id=0) assert (~ch.isclose(p, ch.zeros_like(p))).all().item() @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_last_nonzero_feature(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_last_nonzero_feature( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that output takes into account last features. """ - g = ch.zeros(*input_shape, device='cuda:0', dtype=dtype) - g[:, -1] = 1. - - proj = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + g = ch.zeros(*input_shape, device="cuda:0", dtype=dtype) + g[:, -1] = 1.0 + + batch_size = input_shape[0] + max_chunk_size = get_max_chunk_size(batch_size) + g = make_input(input_shape, max_chunk_size, g_tensor=g) + print(g[0]) + + proj = BasicProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) p = proj.project(g, model_id=0) assert (~ch.isclose(p, ch.zeros_like(p))).all().item() @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_same_features(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_same_features( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that output is the same for the same features """ - g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype) + g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype) g[-1] = g[0] - proj = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + batch_size = input_shape[0] + max_chunk_size = get_max_chunk_size(batch_size) + g = make_input(input_shape, max_chunk_size, g_tensor=g) + for i in range(len(g)): + print(g[i][0] == g[i][-1]) + + rng = np.random.default_rng(seed) + seeds = rng.integers( + low=0, + high=500, + size=len(g), + ) + + param_chunk_sizes = [v.size(1) for v in g.values()] + projector_per_chunk = [ + BasicProjector( + grad_dim=chunk_size, + proj_dim=proj_dim, + seed=seeds[i], + proj_type=proj_type, + max_batch_size=MAX_BATCH_SIZE, + dtype=dtype, + device="cuda:0", + ) + for i, chunk_size in enumerate(param_chunk_sizes) + ] + proj = ChunkedCudaProjector( + projector_per_chunk, + max_chunk_size, + param_chunk_sizes, + batch_size, + "cuda:0", + dtype, + ) p = proj.project(g, model_id=0) assert ch.allclose(p[0], p[-1]) + del g + torch.cuda.empty_cache() + @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_orthogonality(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_orthogonality( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check 
that orthogonality of inputs is approximately preserved w.h.p. """ if input_shape[0] == 1: pass else: - proj = BasicProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + proj = BasicProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) num_successes = 0 - num_trials = 100 + num_trials = 10 for _ in range(num_trials): - g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype) + g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype) g[-1] -= g[0] @ g[-1] / (g[0].norm() ** 2) * g[0] + + batch_size = input_shape[0] + max_chunk_size = get_max_chunk_size(batch_size) + g = make_input(input_shape, max_chunk_size, g_tensor=g) + p = proj.project(g, model_id=0) if p[0] @ p[-1] < 1e-3: num_successes += 1 assert num_successes > 0.33 * num_trials + + del g + torch.cuda.empty_cache() diff --git a/tests/test_jl_additional.py b/tests/test_jl_additional.py index 90d870d..cd18d41 100644 --- a/tests/test_jl_additional.py +++ b/tests/test_jl_additional.py @@ -8,106 +8,121 @@ MAX_BATCH_SIZE = 32 # TEST CASES 1 -PARAM = list(product([123], # seed - [ProjectionType.rademacher], # proj type - [ch.float32], # dtype - [ - (8, 180645096), # pass: np.prod(shape) < np.iinfo(np.int32).max - (16, 180645096), # pass: np.prod(shape) > np.iinfo(np.int32).max - (31, 180645096), # fail: np.prod(shape) > np.iinfo(np.int32).max - (32, 180645096), # fail: np.prod(shape) > np.iinfo(np.int32).max - (33, 180645096), # pass: np.prod(shape) > np.iinfo(np.int32).max - (48, 180645096), # pass: np.prod(shape) > np.iinfo(np.int32).max - (50, 180645096), # pass: np.prod(shape) > np.iinfo(np.int32).max - ], # input shape - [15_360], # proj dim - )) +PARAM = list( + product( + [123], # seed + [ProjectionType.rademacher], # proj type + [ch.float32], # dtype + [ + (8, 180645096), # pass: np.prod(shape) < np.iinfo(np.int32).max + (16, 180645096), # pass: np.prod(shape) > np.iinfo(np.int32).max + (31, 180645096), # fail: np.prod(shape) > np.iinfo(np.int32).max + (32, 180645096), # fail: np.prod(shape) > np.iinfo(np.int32).max + (33, 180645096), # pass: np.prod(shape) > np.iinfo(np.int32).max + (48, 180645096), # pass: np.prod(shape) > np.iinfo(np.int32).max + (50, 180645096), # pass: np.prod(shape) > np.iinfo(np.int32).max + ], # input shape + [15_360], # proj dim + ) +) # TEST CASES 2 -PARAM = list(product([123], # seed - [ProjectionType.rademacher], # proj type - [ch.float32], # dtype - [ - (1, 780645096), # pass: np.prod(shape) < np.iinfo(np.int32).max - (5, 780645096), # pass: np.prod(shape) > np.iinfo(np.int32).max - (6, 780645096), # pass: np.prod(shape) > np.iinfo(np.int32).max - (7, 780645096), # fail: np.prod(shape) > np.iinfo(np.int32).max - (8, 780645096), # fail: np.prod(shape) > np.iinfo(np.int32).max - ], # input shape - [4_096], # proj dim - )) +PARAM = list( + product( + [123], # seed + [ProjectionType.rademacher], # proj type + [ch.float32], # dtype + [ + (1, 780645096), # pass: np.prod(shape) < np.iinfo(np.int32).max + (5, 780645096), # pass: np.prod(shape) > np.iinfo(np.int32).max + (6, 780645096), # pass: np.prod(shape) > np.iinfo(np.int32).max + (7, 780645096), # fail: np.prod(shape) > np.iinfo(np.int32).max + (8, 780645096), # fail: np.prod(shape) > np.iinfo(np.int32).max + ], # input shape + [4_096], # proj dim + ) +) # TEST CASES 3 (ONLY for test_same_features_diff_sms) -PARAM = 
list(product([123], # seed - [ProjectionType.rademacher], # proj type - [ch.float32], # dtype - [ - (32, 100_000), - ], # input shape - [4_096], # proj dim - )) +PARAM = list( + product( + [123], # seed + [ProjectionType.rademacher], # proj type + [ch.float32], # dtype + [ + (32, 100_000), + ], # input shape + [4_096], # proj dim + ) +) + @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_same_features(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_same_features( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that output is the same for the same features """ - g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype) + g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype) g[-1] = g[0] - proj = CudaProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + proj = CudaProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) p = proj.project(g, model_id=0) assert ch.allclose(p[0], p[-1]) + @pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM) @pytest.mark.cuda -def test_same_features_diff_sms(seed, - proj_type, - dtype, - proj_dim, - input_shape, - ): +def test_same_features_diff_sms( + seed, + proj_type, + dtype, + proj_dim, + input_shape, +): """ Check that output is the same for the same features """ - g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype) - + g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype) # project with all SMs available - proj_full_sms = CudaProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + proj_full_sms = CudaProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) p_full_sms = proj_full_sms.project(g, model_id=0) # project with half SMs available - proj_half_sms = CudaProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + proj_half_sms = CudaProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) proj_half_sms.num_sms = max(proj_half_sms.num_sms // 2, 1) p_half_sms = proj_half_sms.project(g, model_id=0) diff --git a/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.py b/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.py index 5a7e380..4edff37 100644 --- a/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.py +++ b/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.py @@ -9,68 +9,78 @@ MAX_BATCH_SIZE = 32 # TEST CASES 1 -PARAM = list(product([123], # seed - [ProjectionType.rademacher], # proj type - [ch.float32], # dtype - [ - (32, 100_000), # pass: np.prod(shape) < np.iinfo(np.int32).max - ], # input shape - [4_096], # proj dim - [108], # num sms - )) +PARAM = list( + product( + [123], # seed + [ProjectionType.rademacher], # proj type + [ch.float32], # dtype + [ + (32, 100_000), # pass: np.prod(shape) < np.iinfo(np.int32).max + ], # input shape + [4_096], # proj dim + [108], # num sms + ) +) 
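A minimal sketch, not part of the patch itself, of the arithmetic behind the "pass"/"fail" comments in the PARAM lists above: the test shapes are chosen so that the total element count of the flattened input straddles np.iinfo(np.int32).max, the point at which 32-bit indexing in a CUDA kernel would overflow. Note that the boundary in the annotations is not a clean threshold (31 and 32 fail while 33 passes), presumably because fast_jl processes the batch in fixed-size chunks; treat the chunking explanation as an assumption rather than a documented property.

```python
# Sketch (ours, not from the diff): reproduce the int32-overflow arithmetic
# referenced by the "pass"/"fail" comments in the PARAM lists above.
import numpy as np

INT32_MAX = np.iinfo(np.int32).max  # 2_147_483_647

for batch_size in [8, 16, 31, 32, 33, 48, 50]:
    n_elements = batch_size * 180_645_096  # np.prod(input_shape)
    status = "overflows int32" if n_elements > INT32_MAX else "fits in int32"
    print(f"({batch_size}, 180_645_096): {n_elements:>14,} elements -> {status}")
```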
-@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim, num_sms", PARAM) +@pytest.mark.parametrize( + "seed, proj_type, dtype, input_shape, proj_dim, num_sms", PARAM +) @pytest.mark.cuda -def test_create_proj(seed, - proj_type, - dtype, - proj_dim, - input_shape, - num_sms, - ): +def test_create_proj( + seed, + proj_type, + dtype, + proj_dim, + input_shape, + num_sms, +): """ Compute the output for each GPU type """ - GPU_NAME = os.environ['GPU_NAME'] - print(f'GPU: {GPU_NAME}') + GPU_NAME = os.environ["GPU_NAME"] + print(f"GPU: {GPU_NAME}") - if os.path.exists(f'./{GPU_NAME}.pt'): - os.remove(f'./{GPU_NAME}.pt') + if os.path.exists(f"./{GPU_NAME}.pt"): + os.remove(f"./{GPU_NAME}.pt") - g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype) + g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype) - proj = CudaProjector(grad_dim=input_shape[-1], - proj_dim=proj_dim, - proj_type=proj_type, - seed=seed, - device='cuda:0', - dtype=dtype, - max_batch_size=MAX_BATCH_SIZE - ) + proj = CudaProjector( + grad_dim=input_shape[-1], + proj_dim=proj_dim, + proj_type=proj_type, + seed=seed, + device="cuda:0", + dtype=dtype, + max_batch_size=MAX_BATCH_SIZE, + ) proj.num_sms = num_sms - print(f'# Projector SMs: {proj.num_sms}') + print(f"# Projector SMs: {proj.num_sms}") p = proj.project(g, model_id=0) - ch.save(p.cpu(), f'./{GPU_NAME}.pt') + ch.save(p.cpu(), f"./{GPU_NAME}.pt") -@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim, num_sms", PARAM) +@pytest.mark.parametrize( + "seed, proj_type, dtype, input_shape, proj_dim, num_sms", PARAM +) @pytest.mark.cuda -def test_same_proj(seed, - proj_type, - dtype, - proj_dim, - input_shape, - num_sms, - ): +def test_same_proj( + seed, + proj_type, + dtype, + proj_dim, + input_shape, + num_sms, +): """ Check that output is the same for different GPUs """ - proj_a100 = ch.load('./A100.pt') - proj_h100 = ch.load('./H100.pt') + proj_a100 = ch.load("./A100.pt") + proj_h100 = ch.load("./H100.pt") - assert ch.allclose(proj_a100, proj_h100), 'GPUs have different projection' \ No newline at end of file + assert ch.allclose(proj_a100, proj_h100), "GPUs have different projection" diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 2b96c25..8c31117 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -12,132 +12,146 @@ @pytest.mark.cuda def test_featurize_and_score_in_parallel(tmp_path): - device = 'cuda:0' + device = "cuda:0" batch_size = 100 model = construct_rn9().to(memory_format=ch.channels_last).to(device) model = model.eval() - BETONS_PATH = Path(tmp_path).joinpath('cifar_betons') + BETONS_PATH = Path(tmp_path).joinpath("cifar_betons") BETONS = download_cifar_betons(BETONS_PATH) - loader_train = get_dataloader(BETONS, batch_size=batch_size, split='train') - loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val') + loader_train = get_dataloader(BETONS, batch_size=batch_size, split="train") + loader_val = get_dataloader(BETONS, batch_size=batch_size, split="val") - CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts') - ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds='cifar2') - ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files] + CKPT_PATH = Path(tmp_path).joinpath("cifar_ckpts") + ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds="cifar2") + ckpts = [ch.load(ckpt, map_location="cpu") for ckpt in ckpt_files] # this should be essentially equivalent to running each # TRAKer in a separate script for model_id, ckpt in enumerate(ckpts): - 
traker = TRAKer(model=model, - task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) traker.load_checkpoint(checkpoint=ckpt, model_id=model_id) - for batch in tqdm(loader_train, desc='Computing TRAK embeddings...'): + for batch in tqdm(loader_train, desc="Computing TRAK embeddings..."): traker.featurize(batch=batch, num_samples=len(batch[0])) traker.finalize_features() for model_id, ckpt in enumerate(ckpts): - traker = TRAKer(model=model, - task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) - - traker.start_scoring_checkpoint('test_experiment', ckpt, model_id, num_targets=2_000) - for batch in tqdm(loader_val, desc='Scoring...'): + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) + + traker.start_scoring_checkpoint( + "test_experiment", ckpt, model_id, num_targets=2_000 + ) + for batch in tqdm(loader_val, desc="Scoring..."): traker.score(batch=batch, num_samples=len(batch[0])) - scores = traker.finalize_scores(exp_name='test_experiment') + scores = traker.finalize_scores(exp_name="test_experiment") - avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds='cifar2') - assert avg_corr > 0.062, 'correlation with 3 CIFAR-2 models should be >= 0.062' + avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds="cifar2") + assert avg_corr > 0.062, "correlation with 3 CIFAR-2 models should be >= 0.062" @pytest.mark.cuda def test_score_multiple(tmp_path): - device = 'cuda:0' + device = "cuda:0" batch_size = 100 model = construct_rn9().to(memory_format=ch.channels_last).to(device) model = model.eval() - BETONS_PATH = Path(tmp_path).joinpath('cifar_betons') + BETONS_PATH = Path(tmp_path).joinpath("cifar_betons") BETONS = download_cifar_betons(BETONS_PATH) - loader_train = get_dataloader(BETONS, batch_size=batch_size, split='train') - loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val') + loader_train = get_dataloader(BETONS, batch_size=batch_size, split="train") + loader_val = get_dataloader(BETONS, batch_size=batch_size, split="val") - CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts') - ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds='cifar2') - ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files] + CKPT_PATH = Path(tmp_path).joinpath("cifar_ckpts") + ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds="cifar2") + ckpts = [ch.load(ckpt, map_location="cpu") for ckpt in ckpt_files] - traker = TRAKer(model=model, - task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) for model_id, ckpt in enumerate(ckpts): traker.load_checkpoint(checkpoint=ckpt, model_id=model_id) - for batch in tqdm(loader_train, desc='Computing TRAK embeddings...'): + for batch in tqdm(loader_train, desc="Computing TRAK embeddings..."): traker.featurize(batch=batch, num_samples=len(batch[0])) traker.finalize_features() scoring_runs = range(3) for _ in scoring_runs: for model_id, ckpt in enumerate(ckpts): - traker = TRAKer(model=model, - 
task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) - - traker.start_scoring_checkpoint('test_experiment', ckpt, model_id, num_targets=2_000) - for batch in tqdm(loader_val, desc='Scoring...'): + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) + + traker.start_scoring_checkpoint( + "test_experiment", ckpt, model_id, num_targets=2_000 + ) + for batch in tqdm(loader_val, desc="Scoring..."): traker.score(batch=batch, num_samples=len(batch[0])) - scores = traker.finalize_scores('test_experiment') + scores = traker.finalize_scores("test_experiment") - avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds='cifar2') - assert avg_corr > 0.062, 'correlation with 3 CIFAR-2 models should be >= 0.062' + avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds="cifar2") + assert avg_corr > 0.062, "correlation with 3 CIFAR-2 models should be >= 0.062" @pytest.mark.cuda def test_score_in_shards(tmp_path): - device = 'cuda:0' + device = "cuda:0" batch_size = 100 model = construct_rn9().to(memory_format=ch.channels_last).to(device) model = model.eval() - BETONS_PATH = Path(tmp_path).joinpath('cifar_betons') + BETONS_PATH = Path(tmp_path).joinpath("cifar_betons") BETONS = download_cifar_betons(BETONS_PATH) - loader_train = get_dataloader(BETONS, batch_size=batch_size, split='train') + loader_train = get_dataloader(BETONS, batch_size=batch_size, split="train") - CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts') - ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds='cifar2') - ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files] + CKPT_PATH = Path(tmp_path).joinpath("cifar_ckpts") + ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds="cifar2") + ckpts = [ch.load(ckpt, map_location="cpu") for ckpt in ckpt_files] - traker = TRAKer(model=model, - task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) for model_id, ckpt in enumerate(ckpts): traker.load_checkpoint(checkpoint=ckpt, model_id=model_id) - for batch in tqdm(loader_train, desc='Computing TRAK embeddings...'): + for batch in tqdm(loader_train, desc="Computing TRAK embeddings..."): traker.featurize(batch=batch, num_samples=len(batch[0])) traker.finalize_features() @@ -145,136 +159,163 @@ def test_score_in_shards(tmp_path): # this should be essentially equivalent to scoring each # shard in a separate script for scoring_inds in scoring_shards: - loader_val = get_dataloader(BETONS, batch_size=batch_size, - split='val', indices=scoring_inds) + loader_val = get_dataloader( + BETONS, batch_size=batch_size, split="val", indices=scoring_inds + ) for model_id, ckpt in enumerate(ckpts): - traker = TRAKer(model=model, - task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) - - traker.start_scoring_checkpoint('test_experiment', ckpt, model_id, num_targets=2000) - for batch_idx, batch in enumerate(tqdm(loader_val, desc='Scoring...')): - batch_inds = scoring_inds[batch_idx * batch_size: (batch_idx + 1) * batch_size] + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + 
logging_level=logging.DEBUG, + ) + + traker.start_scoring_checkpoint( + "test_experiment", ckpt, model_id, num_targets=2000 + ) + for batch_idx, batch in enumerate(tqdm(loader_val, desc="Scoring...")): + batch_inds = scoring_inds[ + batch_idx * batch_size : (batch_idx + 1) * batch_size + ] traker.score(batch=batch, inds=batch_inds) - scores = traker.finalize_scores('test_experiment') + scores = traker.finalize_scores("test_experiment") - avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds='cifar2') - assert avg_corr > 0.062, 'correlation with 3 CIFAR-2 models should be >= 0.062' + avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds="cifar2") + assert avg_corr > 0.062, "correlation with 3 CIFAR-2 models should be >= 0.062" @pytest.mark.cuda def test_featurize_in_shards(tmp_path): - device = 'cuda:0' + device = "cuda:0" batch_size = 100 model = construct_rn9().to(memory_format=ch.channels_last).to(device) model = model.eval() - BETONS_PATH = Path(tmp_path).joinpath('cifar_betons') + BETONS_PATH = Path(tmp_path).joinpath("cifar_betons") BETONS = download_cifar_betons(BETONS_PATH) - loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val') + loader_val = get_dataloader(BETONS, batch_size=batch_size, split="val") - CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts') - ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds='cifar2') - ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files] + CKPT_PATH = Path(tmp_path).joinpath("cifar_ckpts") + ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds="cifar2") + ckpts = [ch.load(ckpt, map_location="cpu") for ckpt in ckpt_files] # this should be essentially equivalent to featurizing each # shard in a separate script featurizing_shards = [np.arange(5000), np.arange(5000, 10_000)] for featurizing_inds in featurizing_shards: - loader_train = get_dataloader(BETONS, batch_size=batch_size, - split='train', indices=featurizing_inds) - traker = TRAKer(model=model, - task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) + loader_train = get_dataloader( + BETONS, batch_size=batch_size, split="train", indices=featurizing_inds + ) + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) for model_id, ckpt in enumerate(ckpts): traker.load_checkpoint(checkpoint=ckpt, model_id=model_id) - for batch_idx, batch in enumerate(tqdm(loader_train, desc='Computing TRAK embeddings')): - batch_inds = featurizing_inds[batch_idx * batch_size: (batch_idx + 1) * batch_size] + for batch_idx, batch in enumerate( + tqdm(loader_train, desc="Computing TRAK embeddings") + ): + batch_inds = featurizing_inds[ + batch_idx * batch_size : (batch_idx + 1) * batch_size + ] traker.featurize(batch=batch, inds=batch_inds) traker.finalize_features() - traker = TRAKer(model=model, - task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) for model_id, ckpt in enumerate(ckpts): - - traker.start_scoring_checkpoint('test_experiment', ckpt, model_id, num_targets=2_000) - for batch in tqdm(loader_val, desc='Scoring...'): + traker.start_scoring_checkpoint( + "test_experiment", ckpt, model_id, num_targets=2_000 + ) + for batch in tqdm(loader_val, 
desc="Scoring..."): traker.score(batch=batch, num_samples=len(batch[0])) - scores = traker.finalize_scores('test_experiment') + scores = traker.finalize_scores("test_experiment") - avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds='cifar2') - assert avg_corr > 0.062, 'correlation with 3 CIFAR-2 models should be >= 0.062' + avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds="cifar2") + assert avg_corr > 0.062, "correlation with 3 CIFAR-2 models should be >= 0.062" @pytest.mark.cuda def test_preemption(tmp_path): - device = 'cuda:0' + device = "cuda:0" batch_size = 100 model = construct_rn9().to(memory_format=ch.channels_last).to(device) model = model.eval() - BETONS_PATH = Path(tmp_path).joinpath('cifar_betons') + BETONS_PATH = Path(tmp_path).joinpath("cifar_betons") BETONS = download_cifar_betons(BETONS_PATH) - loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val') + loader_val = get_dataloader(BETONS, batch_size=batch_size, split="val") - CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts') - ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds='cifar2') - ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files] + CKPT_PATH = Path(tmp_path).joinpath("cifar_ckpts") + ckpt_files = download_cifar_checkpoints(CKPT_PATH, ds="cifar2") + ckpts = [ch.load(ckpt, map_location="cpu") for ckpt in ckpt_files] # this should be essentially equivalent to featurizing each # shard in a separate script featurizing_shards = [np.arange(5000), np.arange(10_000)] for featurizing_inds in featurizing_shards: - loader_train = get_dataloader(BETONS, batch_size=batch_size, - split='train', indices=featurizing_inds) - traker = TRAKer(model=model, - task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) + loader_train = get_dataloader( + BETONS, batch_size=batch_size, split="train", indices=featurizing_inds + ) + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) for model_id, ckpt in enumerate(ckpts): traker.load_checkpoint(checkpoint=ckpt, model_id=model_id) - for batch_idx, batch in enumerate(tqdm(loader_train, desc='Computing TRAK embeddings')): - batch_inds = featurizing_inds[batch_idx * batch_size: (batch_idx + 1) * batch_size] + for batch_idx, batch in enumerate( + tqdm(loader_train, desc="Computing TRAK embeddings") + ): + batch_inds = featurizing_inds[ + batch_idx * batch_size : (batch_idx + 1) * batch_size + ] traker.featurize(batch=batch, inds=batch_inds) traker.finalize_features() - traker = TRAKer(model=model, - task='image_classification', - train_set_size=10_000, - save_dir=tmp_path, - device=device, - logging_level=logging.DEBUG) + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=10_000, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + ) for model_id, ckpt in enumerate(ckpts): - - traker.start_scoring_checkpoint('test_experiment', ckpt, model_id, num_targets=2_000) - for batch in tqdm(loader_val, desc='Scoring...'): + traker.start_scoring_checkpoint( + "test_experiment", ckpt, model_id, num_targets=2_000 + ) + for batch in tqdm(loader_val, desc="Scoring..."): traker.score(batch=batch, num_samples=len(batch[0])) - scores = traker.finalize_scores('test_experiment') + scores = traker.finalize_scores("test_experiment") - avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds='cifar2') - assert avg_corr > 0.062, 
'correlation with 3 CIFAR-2 models should be >= 0.062' + avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path, ds="cifar2") + assert avg_corr > 0.062, "correlation with 3 CIFAR-2 models should be >= 0.062" diff --git a/tests/test_rademacher.py b/tests/test_rademacher.py index 7a0beb1..e3c5265 100644 --- a/tests/test_rademacher.py +++ b/tests/test_rademacher.py @@ -6,13 +6,13 @@ try: import fast_jl except ModuleNotFoundError: - print('No fast_jl available!') + print("No fast_jl available!") from assertpy import assert_that -bs_error_str = 'CUDA error: too many resources requested for launch\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n' # noqa +bs_error_str = "CUDA error: too many resources requested for launch\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n" # noqa -new_bs_error_str = f'The batch size of the CudaProjector is too large for your GPU. Reduce it by using the max_batch_size argument of the CudaProjector.\nOriginal error: {bs_error_str}' # noqa +new_bs_error_str = f"The batch size of the CudaProjector is too large for your GPU. Reduce it by using the max_batch_size argument of the CudaProjector.\nOriginal error: {bs_error_str}" # noqa PARAM = list(product([8], [1024, 2048], [512, 1024, 2048], [0, 1])) @@ -24,7 +24,9 @@ def test_shape(bs: int, input_size: int, output_size: int, seed: int): print(output_size) input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0") - num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count + num_sms = ch.cuda.get_device_properties( + ch.cuda.current_device() + ).multi_processor_count try: result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms) @@ -45,7 +47,9 @@ def test_running(): output_size = 512 input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0") - num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count + num_sms = ch.cuda.get_device_properties( + ch.cuda.current_device() + ).multi_processor_count try: result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms) @@ -66,7 +70,9 @@ def test_even(): output_size = 1024 input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0") - num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count + num_sms = ch.cuda.get_device_properties( + ch.cuda.current_device() + ).multi_processor_count try: result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms) @@ -87,7 +93,9 @@ def test_odd(): output_size = 2048 input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0") - num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count + num_sms = ch.cuda.get_device_properties( + ch.cuda.current_device() + ).multi_processor_count try: result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms) diff --git a/tests/test_rademacher_additional.py b/tests/test_rademacher_additional.py index 53a4f86..ffd12da 100644 --- a/tests/test_rademacher_additional.py +++ b/tests/test_rademacher_additional.py @@ -6,13 +6,13 @@ try: import fast_jl except ModuleNotFoundError: - 
print('No fast_jl available!') + print("No fast_jl available!") from assertpy import assert_that -bs_error_str = 'CUDA error: too many resources requested for launch\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n' # noqa +bs_error_str = "CUDA error: too many resources requested for launch\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n" # noqa -new_bs_error_str = f'The batch size of the CudaProjector is too large for your GPU. Reduce it by using the max_batch_size argument of the CudaProjector.\nOriginal error: {bs_error_str}' # noqa +new_bs_error_str = f"The batch size of the CudaProjector is too large for your GPU. Reduce it by using the max_batch_size argument of the CudaProjector.\nOriginal error: {bs_error_str}" # noqa PARAM = list(product([8, 16, 32, 48], [180645096], [2048, 4096, 15_360], [0])) @@ -25,7 +25,9 @@ def test_shape(bs: int, input_size: int, output_size: int, seed: int): print(output_size) input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0") - num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count + num_sms = ch.cuda.get_device_properties( + ch.cuda.current_device() + ).multi_processor_count try: result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms) @@ -46,7 +48,9 @@ def test_running(): output_size = 512 input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0") - num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count + num_sms = ch.cuda.get_device_properties( + ch.cuda.current_device() + ).multi_processor_count try: result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms) @@ -67,7 +71,9 @@ def test_even(): output_size = 1024 input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0") - num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count + num_sms = ch.cuda.get_device_properties( + ch.cuda.current_device() + ).multi_processor_count try: result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms) @@ -88,7 +94,9 @@ def test_odd(): output_size = 2048 input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0") - num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count + num_sms = ch.cuda.get_device_properties( + ch.cuda.current_device() + ).multi_processor_count try: result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms) diff --git a/tests/utils.py b/tests/utils.py index 28a5b99..b3427d1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,61 +7,76 @@ import numpy as np import torch import torchvision + ch = torch try: from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder from ffcv.loader import Loader, OrderOption from ffcv.pipeline.operation import Operation - from ffcv.transforms import RandomHorizontalFlip, Cutout, \ - RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage + from ffcv.transforms import ( + RandomHorizontalFlip, + Cutout, + RandomTranslate, + Convert, + ToDevice, + ToTensor, + ToTorchImage, + ) from ffcv.transforms.common import Squeeze except ImportError: - 
print('No ffcv installed') - - -STATS = { - 'mean': [125.307, 122.961, 113.8575], - 'std': [51.5865, 50.847, 51.255] -} - - -def get_dataloader(BETONS, - batch_size=256, - num_workers=8, - split='train', # split \in [train, val] - aug_seed=0, - should_augment=True, - indices=None): - label_pipeline: List[Operation] = [IntDecoder(), - ToTensor(), - ToDevice(ch.device('cuda:0')), - Squeeze()] + print("No ffcv installed") + + +STATS = {"mean": [125.307, 122.961, 113.8575], "std": [51.5865, 50.847, 51.255]} + + +def get_dataloader( + BETONS, + batch_size=256, + num_workers=8, + split="train", # split \in [train, val] + aug_seed=0, + should_augment=True, + indices=None, +): + label_pipeline: List[Operation] = [ + IntDecoder(), + ToTensor(), + ToDevice(ch.device("cuda:0")), + Squeeze(), + ] image_pipeline: List[Operation] = [SimpleRGBImageDecoder()] if should_augment: - image_pipeline.extend([ + image_pipeline.extend( + [ RandomHorizontalFlip(), - RandomTranslate(padding=2, fill=tuple(map(int, STATS['mean']))), - Cutout(4, tuple(map(int, STATS['mean']))), - ]) + RandomTranslate(padding=2, fill=tuple(map(int, STATS["mean"]))), + Cutout(4, tuple(map(int, STATS["mean"]))), + ] + ) - image_pipeline.extend([ - ToTensor(), - ToDevice(ch.device('cuda:0'), non_blocking=True), - ToTorchImage(), - Convert(ch.float32), - torchvision.transforms.Normalize(STATS['mean'], STATS['std']), - ]) - - return Loader(BETONS[split], - batch_size=batch_size, - num_workers=num_workers, - order=OrderOption.SEQUENTIAL, - drop_last=False, - seed=aug_seed, - indices=indices, - pipelines={'image': image_pipeline, 'label': label_pipeline}) + image_pipeline.extend( + [ + ToTensor(), + ToDevice(ch.device("cuda:0"), non_blocking=True), + ToTorchImage(), + Convert(ch.float32), + torchvision.transforms.Normalize(STATS["mean"], STATS["std"]), + ] + ) + + return Loader( + BETONS[split], + batch_size=batch_size, + num_workers=num_workers, + order=OrderOption.SEQUENTIAL, + drop_last=False, + seed=aug_seed, + indices=indices, + pipelines={"image": image_pipeline, "label": label_pipeline}, + ) # Resnet9 @@ -75,7 +90,8 @@ def forward(self, x): class Flatten(ch.nn.Module): - def forward(self, x): return x.view(x.size(0), -1) + def forward(self, x): + return x.view(x.size(0), -1) class Residual(ch.nn.Module): @@ -88,13 +104,23 @@ def forward(self, x): def construct_rn9(num_classes=2): - def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1): + def conv_bn( + channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1 + ): return ch.nn.Sequential( - ch.nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, - stride=stride, padding=padding, groups=groups, bias=False), - ch.nn.BatchNorm2d(channels_out), - ch.nn.ReLU(inplace=True) + ch.nn.Conv2d( + channels_in, + channels_out, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False, + ), + ch.nn.BatchNorm2d(channels_out), + ch.nn.ReLU(inplace=True), ) + model = ch.nn.Sequential( conv_bn(3, 64, kernel_size=3, stride=1, padding=1), conv_bn(64, 128, kernel_size=5, stride=2, padding=2), @@ -106,64 +132,65 @@ def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, group ch.nn.AdaptiveMaxPool2d((1, 1)), Flatten(), ch.nn.Linear(128, num_classes, bias=False), - Mul(0.2) + Mul(0.2), ) return model def download_cifar_betons(BETON_PATH): - url_train = 'https://www.dropbox.com/s/0llwyuja7u0s9an/train.beton?dl=1' - url_val = 'https://www.dropbox.com/s/63ef3g8dsq32484/val.beton?dl=1' + url_train = 
"https://www.dropbox.com/s/0llwyuja7u0s9an/train.beton?dl=1" + url_val = "https://www.dropbox.com/s/63ef3g8dsq32484/val.beton?dl=1" os.makedirs(BETON_PATH, exist_ok=True) - train_path = Path(BETON_PATH).joinpath('cifar_train.beton') + train_path = Path(BETON_PATH).joinpath("cifar_train.beton") wget.download(url_train, out=str(train_path), bar=None) - val_path = Path(BETON_PATH).joinpath('cifar_val.beton') + val_path = Path(BETON_PATH).joinpath("cifar_val.beton") wget.download(url_val, out=str(val_path), bar=None) - return {'train': train_path, - 'val': val_path} + return {"train": train_path, "val": val_path} -def download_cifar_checkpoints(CKPT_PATH, ds='cifar10'): - if ds == 'cifar10': - urls = ['https://www.dropbox.com/s/g2f6mlit151bapk/ckpt_0.pt?dl=1', - 'https://www.dropbox.com/s/0avlz649tmwr7fv/ckpt_1.pt?dl=1', - 'https://www.dropbox.com/s/qafphepxnav2igr/ckpt_2.pt?dl=1' - ] +def download_cifar_checkpoints(CKPT_PATH, ds="cifar10"): + if ds == "cifar10": + urls = [ + "https://www.dropbox.com/s/g2f6mlit151bapk/ckpt_0.pt?dl=1", + "https://www.dropbox.com/s/0avlz649tmwr7fv/ckpt_1.pt?dl=1", + "https://www.dropbox.com/s/qafphepxnav2igr/ckpt_2.pt?dl=1", + ] else: - urls = ['https://www.dropbox.com/s/n2p96rbvdy5xruy/model_sd_97.pt?dl=1', - 'https://www.dropbox.com/s/vljde3qwadaqwbt/model_sd_98.pt?dl=1', - 'https://www.dropbox.com/s/ehwx0u131214uak/model_sd_99.pt?dl=1' - ] + urls = [ + "https://www.dropbox.com/s/n2p96rbvdy5xruy/model_sd_97.pt?dl=1", + "https://www.dropbox.com/s/vljde3qwadaqwbt/model_sd_98.pt?dl=1", + "https://www.dropbox.com/s/ehwx0u131214uak/model_sd_99.pt?dl=1", + ] os.makedirs(CKPT_PATH, exist_ok=True) for ind, url in enumerate(urls): - ckpt_path = Path(CKPT_PATH).joinpath(f'sd_{ind}.pt') + ckpt_path = Path(CKPT_PATH).joinpath(f"sd_{ind}.pt") wget.download(url, out=str(ckpt_path), bar=None) return list(Path(CKPT_PATH).rglob("*.pt")) -def eval_correlations(infls, tmp_path, ds='cifar10'): - if ds == 'cifar10': - masks_url = 'https://www.dropbox.com/s/x76uyen8ffkjfke/mask.npy?dl=1' - margins_url = 'https://www.dropbox.com/s/q1dxoxw78ct7c27/val_margins.npy?dl=1' +def eval_correlations(infls, tmp_path, ds="cifar10"): + if ds == "cifar10": + masks_url = "https://www.dropbox.com/s/x76uyen8ffkjfke/mask.npy?dl=1" + margins_url = "https://www.dropbox.com/s/q1dxoxw78ct7c27/val_margins.npy?dl=1" else: - masks_url = 'https://www.dropbox.com/s/2nmcjaftdavyg0m/mask.npy?dl=1' - margins_url = 'https://www.dropbox.com/s/tc3r3c3kgna2h27/val_margins.npy?dl=1' + masks_url = "https://www.dropbox.com/s/2nmcjaftdavyg0m/mask.npy?dl=1" + margins_url = "https://www.dropbox.com/s/tc3r3c3kgna2h27/val_margins.npy?dl=1" - masks_path = Path(tmp_path).joinpath('mask.npy') + masks_path = Path(tmp_path).joinpath("mask.npy") wget.download(masks_url, out=str(masks_path), bar=None) # num masks, num train samples - masks = ch.as_tensor(np.load(masks_path, mmap_mode='r')).float() + masks = ch.as_tensor(np.load(masks_path, mmap_mode="r")).float() - margins_path = Path(tmp_path).joinpath('val_margins.npy') + margins_path = Path(tmp_path).joinpath("val_margins.npy") wget.download(margins_url, out=str(margins_path), bar=None) # num , num val samples - margins = ch.as_tensor(np.load(margins_path, mmap_mode='r')) + margins = ch.as_tensor(np.load(margins_path, mmap_mode="r")) val_inds = np.arange(2000) preds = masks @ infls @@ -174,5 +201,5 @@ def eval_correlations(infls, tmp_path, ds='cifar10'): rs.append(r) ps.append(p) rs, ps = np.array(rs), np.array(ps) - print(f'Correlation: {rs.mean()} (avg p value {ps.mean()})') 
+ print(f"Correlation: {rs.mean()} (avg p value {ps.mean()})") return rs.mean() diff --git a/trak/__init__.py b/trak/__init__.py index f627868..b7bf43e 100644 --- a/trak/__init__.py +++ b/trak/__init__.py @@ -1,6 +1,5 @@ from .traker import TRAKer from .utils import test_install -__version__ = '0.2.2' +__version__ = "0.3.0" VERSION = __version__ - diff --git a/trak/gradient_computers.py b/trak/gradient_computers.py index 4a3bbdc..b000e8b 100644 --- a/trak/gradient_computers.py +++ b/trak/gradient_computers.py @@ -1,14 +1,32 @@ +""" +Computing features for the TRAK algorithm involves computing (and projecting) +per-sample gradients. This module contains classes that compute these +per-sample gradients. The :code:`AbstractFeatureComputer` class defines the +interface for such gradient computers. Then, we provide two implementations: +- :class:`FunctionalFeatureComputer`: A fast implementation that uses + :code:`torch.func` to vectorize the computation of per-sample gradients, and + thus fully levereage parallelism. +- :class:`IterativeFeatureComputer`: A more naive implementation that only uses + native pytorch operations (i.e. no :code:`torch.func`), and computes per-sample + gradients in a for-loop. This is often much slower than the functional + version, but it is useful if you cannot use :code:`torch.func`, e.g., if you + have an old version of pytorch that does not support it, or if your application + is not supported by :code:`torch.func`. + +""" from abc import ABC, abstractmethod from typing import Iterable, Optional from torch import Tensor -from .utils import vectorize, get_num_params, parameters_to_vector +from .utils import get_num_params, parameters_to_vector from .modelout_functions import AbstractModelOutput +import logging import torch + ch = torch class AbstractGradientComputer(ABC): - """ Implementations of the GradientComputer class should allow for + """Implementations of the GradientComputer class should allow for per-sample gradients. This is behavior is enabled with three methods: - the :meth:`.load_model_params` method, well, loads model parameters. It can @@ -24,14 +42,15 @@ class AbstractGradientComputer(ABC): """ @abstractmethod - def __init__(self, - model: torch.nn.Module, - task: AbstractModelOutput, - grad_dim: Optional[int] = None, - dtype: Optional[torch.dtype] = torch.float16, - device: Optional[torch.device] = 'cuda', - ) -> None: - """ Initializes attributes, nothing too interesting happening. + def __init__( + self, + model: torch.nn.Module, + task: AbstractModelOutput, + grad_dim: Optional[int] = None, + dtype: Optional[torch.dtype] = torch.float16, + device: Optional[torch.device] = "cuda", + ) -> None: + """Initializes attributes, nothing too interesting happening. Args: model (torch.nn.Module): @@ -67,19 +86,32 @@ def compute_loss_grad(self, batch: Iterable[Tensor], batch_size: int) -> Tensor: class FunctionalGradientComputer(AbstractGradientComputer): - def __init__(self, - model: torch.nn.Module, - task: AbstractModelOutput, - grad_dim: int, - dtype: torch.dtype, - device: torch.device) -> None: + def __init__( + self, + model: torch.nn.Module, + task: AbstractModelOutput, + grad_dim: int, + dtype: torch.dtype, + device: torch.device, + grad_wrt: Optional[Iterable[str]] = None, + ) -> None: + """Initializes attributes, and loads model parameters. + + Args: + grad_wrt (list[str], optional): + A list of parameter names for which to keep gradients. If None, + gradients are taken with respect to all model parameters. + Defaults to None. 
+ """ super().__init__(model, task, grad_dim, dtype, device) self.model = model self.num_params = get_num_params(self.model) self.load_model_params(model) + self.grad_wrt = grad_wrt + self.logger = logging.getLogger("GradientComputer") def load_model_params(self, model) -> None: - """ Given a a torch.nn.Module model, inits/updates the (functional) + """Given a a torch.nn.Module model, inits/updates the (functional) weights and buffers. See https://pytorch.org/docs/stable/func.html for more details on :code:`torch.func`'s functional models. @@ -92,7 +124,7 @@ def load_model_params(self, model) -> None: self.func_buffers = dict(model.named_buffers()) def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor: - """ Uses functorch's :code:`vmap` (see + """Uses functorch's :code:`vmap` (see https://pytorch.org/functorch/stable/generated/functorch.vmap.html#functorch.vmap for more details) to vectorize the computations of per-sample gradients. @@ -110,20 +142,21 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor: """ # taking the gradient wrt weights (second argument of get_output, hence argnums=1) - grads_loss = torch.func.grad(self.modelout_fn.get_output, has_aux=False, argnums=1) - # map over batch dimensions (hence 0 for each batch dimension, and None for model params) - grads = torch.empty(size=(batch[0].shape[0], self.num_params), - dtype=self.dtype, - device=self.device) - - vectorize(torch.func.vmap(grads_loss, - in_dims=(None, None, None, *([0] * len(batch))), - randomness='different')(self.model, - self.func_weights, - self.func_buffers, - *batch), - grads) + grads_loss = torch.func.grad( + self.modelout_fn.get_output, has_aux=False, argnums=1 + ) + # map over batch dimensions (hence 0 for each batch dimension, and None for model params) + grads = torch.func.vmap( + grads_loss, + in_dims=(None, None, None, *([0] * len(batch))), + randomness="different", + )(self.model, self.func_weights, self.func_buffers, *batch) + + if self.grad_wrt is not None: + for param_name in list(grads.keys()): + if param_name not in self.grad_wrt: + del grads[param_name] return grads def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor: @@ -151,28 +184,36 @@ def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor: batch of data """ - return self.modelout_fn.get_out_to_loss_grad(self.model, - self.func_weights, - self.func_buffers, - batch) + return self.modelout_fn.get_out_to_loss_grad( + self.model, self.func_weights, self.func_buffers, batch + ) class IterativeGradientComputer(AbstractGradientComputer): - def __init__(self, - model, - task: AbstractModelOutput, - grad_dim: int, - dtype: torch.dtype, - device: torch.device) -> None: + def __init__( + self, + model, + task: AbstractModelOutput, + grad_dim: int, + dtype: torch.dtype, + device: torch.device, + grad_wrt: Optional[Iterable[str]] = None, + ) -> None: super().__init__(model, task, grad_dim, dtype, device) self.load_model_params(model) + self.grad_wrt = grad_wrt + self.logger = logging.getLogger("GradientComputer") + if self.grad_wrt is not None: + self.logger.warning( + "IterativeGradientComputer: ignoring grad_wrt argument." 
+ ) def load_model_params(self, model) -> Tensor: self.model = model self.model_params = list(self.model.parameters()) def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor: - """ Computes per-sample gradients of the model output function This + """Computes per-sample gradients of the model output function. This method does not leverage vectorization (and is hence much slower than its equivalent in :class:`.FunctionalGradientComputer`). We recommend that you use this only if :code:`torch.func` is not available to you, @@ -190,9 +231,9 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor: margin = self.modelout_fn.get_output(self.model, *batch) for ind in range(batch_size): - grads[ind] = parameters_to_vector(ch.autograd.grad(margin[ind], - self.model_params, - retain_graph=True)) + grads[ind] = parameters_to_vector( + ch.autograd.grad(margin[ind], self.model_params, retain_graph=True) + ) return grads def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor: diff --git a/trak/modelout_functions.py b/trak/modelout_functions.py index 8a0b6e6..47ef9b7 100644 --- a/trak/modelout_functions.py +++ b/trak/modelout_functions.py @@ -23,7 +23,7 @@ class AbstractModelOutput(ABC): - """ See, e.g. `this tutorial `_ + """See, e.g. `this tutorial `_ for an example on how to subclass :code:`AbstractModelOutput` for a task of your choice. @@ -40,15 +40,14 @@ class AbstractModelOutput(ABC): the diagonal of) :math:`Q`. """ + @abstractmethod def __init__(self) -> None: pass @abstractmethod - def get_output(self, - model, - batch: Iterable[Tensor]) -> Tensor: - """ See Sections 2 & 3 of `our paper + def get_output(self, model, batch: Iterable[Tensor]) -> Tensor: + """See Sections 2 & 3 of `our paper `_ for more details on what model output functions are in the context of TRAK and how to use & design them. @@ -66,10 +65,8 @@ def get_output(self, ... @abstractmethod - def get_out_to_loss_grad(self, - model, - batch: Iterable[Tensor]) -> Tensor: - """ See Sections 2 & 3 of `our paper + def get_out_to_loss_grad(self, model, batch: Iterable[Tensor]) -> Tensor: + """See Sections 2 & 3 of `our paper `_ for more details on what the out-to-loss functions (in the notation of the paper, :math:`Q`) are in the context of TRAK and how to use & design them. @@ -85,11 +82,11 @@ def get_out_to_loss_grad(self, class ImageClassificationModelOutput(AbstractModelOutput): - """ Margin for (multiclass) image classification. See Section 3.3 of `our + """Margin for (multiclass) image classification. See Section 3.3 of `our paper `_ for more details. """ - def __init__(self, temperature: float = 1.) -> None: + def __init__(self, temperature: float = 1.0) -> None: """ Args: temperature (float, optional): Temperature to use inside the @@ -100,12 +97,14 @@ def __init__(self, temperature: float = 1.) -> None: self.loss_temperature = temperature @staticmethod - def get_output(model: Module, - weights: Iterable[Tensor], - buffers: Iterable[Tensor], - image: Tensor, - label: Tensor) -> Tensor: - """ For a given input :math:`z=(x, y)` and model parameters :math:`\\theta`, + def get_output( + model: Module, + weights: Iterable[Tensor], + buffers: Iterable[Tensor], + image: Tensor, + label: Tensor, + ) -> Tensor: + """For a given input :math:`z=(x, y)` and model parameters :math:`\\theta`, let :math:`p(z, \\theta)` be the softmax probability of the correct class. 
This method implements the model output function @@ -143,13 +142,17 @@ def get_output(model: Module, cloned_logits = logits.clone() # remove the logits of the correct labels from the sum # in logsumexp by setting to -ch.inf - cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(-ch.inf, device=logits.device, dtype=logits.dtype) + cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor( + -ch.inf, device=logits.device, dtype=logits.dtype + ) margins = logits_correct - cloned_logits.logsumexp(dim=-1) return margins.sum() - def get_out_to_loss_grad(self, model, weights, buffers, batch: Iterable[Tensor]) -> Tensor: - """ Computes the (reweighting term Q in the paper) + def get_out_to_loss_grad( + self, model, weights, buffers, batch: Iterable[Tensor] + ) -> Tensor: + """Computes the (reweighting term Q in the paper) Args: model (torch.nn.Module): @@ -169,12 +172,14 @@ def get_out_to_loss_grad(self, model, weights, buffers, batch: Iterable[Tensor]) logits = ch.func.functional_call(model, (weights, buffers), images) # here we are directly implementing the gradient instead of relying on autodiff to do # that for us - ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels] + ps = self.softmax(logits / self.loss_temperature)[ + ch.arange(logits.size(0)), labels + ] return (1 - ps).clone().detach().unsqueeze(-1) class CLIPModelOutput(AbstractModelOutput): - """ Margin for multimodal contrastive learning (CLIP). See Section 5.1 of + """Margin for multimodal contrastive learning (CLIP). See Section 5.1 of `our paper `_ for more details. Compatible with the open_clip implementation of CLIP. @@ -183,12 +188,15 @@ class CLIPModelOutput(AbstractModelOutput): CLIP embeddings, which are computed using the :func:`get_embeddings` method. This method should be invoked before featurizing. """ + num_computed_embeddings = 0 sim_batch_size = 0 image_embeddings = None text_embeddings = None - def __init__(self, temperature: float = None, simulated_batch_size: int = 300) -> None: + def __init__( + self, temperature: float = None, simulated_batch_size: int = 300 + ) -> None: """ Args: @@ -212,15 +220,16 @@ def __init__(self, temperature: float = None, simulated_batch_size: int = 300) - CLIPModelOutput.sim_batch_size = simulated_batch_size @staticmethod - def get_embeddings(model, - loader, - batch_size: int, - embedding_dim: int, - size: int = 50_000, - preprocess_fn_img=None, - preprocess_fn_txt=None, - ) -> None: - """ Computes (image and text) embeddings and saves them in the class + def get_embeddings( + model, + loader, + batch_size: int, + embedding_dim: int, + size: int = 50_000, + preprocess_fn_img=None, + preprocess_fn_txt=None, + ) -> None: + """Computes (image and text) embeddings and saves them in the class attributes :code:`image_embeddings` and :code:`text_embeddings`. Args: @@ -242,8 +251,10 @@ def get_embeddings(model, pass. Defaults to None. 
""" - img_embs, txt_embs = ch.zeros(size, embedding_dim).cuda(),\ - ch.zeros(size, embedding_dim).cuda() + img_embs, txt_embs = ( + ch.zeros(size, embedding_dim).cuda(), + ch.zeros(size, embedding_dim).cuda(), + ) cutoff = batch_size with ch.no_grad(): @@ -256,8 +267,8 @@ def get_embeddings(model, if ed == size: cutoff = size - ind * batch_size image_embeddings, text_embeddings, _ = model(images, text) - img_embs[st: ed] = image_embeddings[:cutoff].clone().detach() - txt_embs[st: ed] = text_embeddings[:cutoff].clone().detach() + img_embs[st:ed] = image_embeddings[:cutoff].clone().detach() + txt_embs[st:ed] = text_embeddings[:cutoff].clone().detach() if (ind + 1) * batch_size >= size: break @@ -266,12 +277,14 @@ def get_embeddings(model, CLIPModelOutput.num_computed_embeddings = size @staticmethod - def get_output(model: Module, - weights: Iterable[Tensor], - buffers: Iterable[Tensor], - image: Tensor, - label: Tensor) -> Tensor: - """ For a given input :math:`z=(x, y)` and model parameters + def get_output( + model: Module, + weights: Iterable[Tensor], + buffers: Iterable[Tensor], + image: Tensor, + label: Tensor, + ) -> Tensor: + """For a given input :math:`z=(x, y)` and model parameters :math:`\\theta`, let :math:`\\phi(x, \\theta)` be the CLIP image embedding and :math:`\\psi(y, \\theta)` be the CLIP text embedding. Last, let :math:`B` be a (simulated) batch. This method implements the @@ -313,26 +326,32 @@ def get_output(model: Module, sim_bs = CLIPModelOutput.sim_batch_size if all_im_embs is None: - raise AssertionError('Run traker.task.get_embeddings first before featurizing!') + raise AssertionError( + "Run traker.task.get_embeddings first before featurizing!" + ) # tailored for open_clip # https://github.com/mlfoundations/open_clip/blob/fb72f4db1b17133befd6c67c9cf32a533b85a321/src/open_clip/model.py#L242-L245 - clip_inputs = {'image': image.unsqueeze(0), 'text': label.unsqueeze(0)} - image_embeddings, text_embeddings, _ = ch.func.functional_call(model, - (weights, buffers), - args=(), - kwargs=clip_inputs) - - ii = ch.multinomial(input=ch.arange(N).float(), - num_samples=sim_bs, - replacement=False) - - result = -ch.logsumexp(-image_embeddings @ (text_embeddings - all_txt_embs[ii]).T, dim=1) +\ - -ch.logsumexp(-text_embeddings @ (image_embeddings - all_im_embs[ii]).T, dim=1) + clip_inputs = {"image": image.unsqueeze(0), "text": label.unsqueeze(0)} + image_embeddings, text_embeddings, _ = ch.func.functional_call( + model, (weights, buffers), args=(), kwargs=clip_inputs + ) + + ii = ch.multinomial( + input=ch.arange(N).float(), num_samples=sim_bs, replacement=False + ) + + result = -ch.logsumexp( + -image_embeddings @ (text_embeddings - all_txt_embs[ii]).T, dim=1 + ) + -ch.logsumexp( + -text_embeddings @ (image_embeddings - all_im_embs[ii]).T, dim=1 + ) return result.sum() # shape of result should be [1] - def get_out_to_loss_grad(self, model, weights, buffers, batch: Iterable[Tensor]) -> Tensor: - """ Computes the (reweighting term Q in the paper) + def get_out_to_loss_grad( + self, model, weights, buffers, batch: Iterable[Tensor] + ) -> Tensor: + """Computes the (reweighting term Q in the paper) Args: model (torch.nn.Module): @@ -350,20 +369,19 @@ def get_out_to_loss_grad(self, model, weights, buffers, batch: Iterable[Tensor]) """ image, label = batch - clip_inputs = {'image': image, 'text': label} - image_embeddings, text_embeddings, temp = ch.func.functional_call(model, - (weights, buffers), - args=(), - kwargs=clip_inputs) + clip_inputs = {"image": image, "text": label} + 
image_embeddings, text_embeddings, temp = ch.func.functional_call( + model, (weights, buffers), args=(), kwargs=clip_inputs + ) if self.temperature is None: self.temperature = temp res = self.temperature * image_embeddings @ text_embeddings.T - ps = (self.softmax(res) + self.softmax(res.T)).diag() / 2. + ps = (self.softmax(res) + self.softmax(res.T)).diag() / 2.0 return (1 - ps).clone().detach() class TextClassificationModelOutput(AbstractModelOutput): - """ Margin for text classification models. This assumes that the model takes + """Margin for text classification models. This assumes that the model takes in input_ids, token_type_ids, and attention_mask. .. math:: @@ -373,52 +391,61 @@ class TextClassificationModelOutput(AbstractModelOutput): """ - def __init__(self, temperature=1.) -> None: + def __init__(self, temperature=1.0) -> None: super().__init__() self.softmax = ch.nn.Softmax(-1) self.loss_temperature = temperature @staticmethod - def get_output(model, - weights: Iterable[Tensor], - buffers: Iterable[Tensor], - input_id: Tensor, - token_type_id: Tensor, - attention_mask: Tensor, - label: Tensor, - ) -> Tensor: - kw_inputs = {'input_ids': input_id.unsqueeze(0), - 'token_type_ids': token_type_id.unsqueeze(0), - 'attention_mask': attention_mask.unsqueeze(0)} - - logits = ch.func.functional_call(model, - (weights, buffers), - args=(), - kwargs=kw_inputs) + def get_output( + model, + weights: Iterable[Tensor], + buffers: Iterable[Tensor], + input_id: Tensor, + token_type_id: Tensor, + attention_mask: Tensor, + label: Tensor, + ) -> Tensor: + kw_inputs = { + "input_ids": input_id.unsqueeze(0), + "token_type_ids": token_type_id.unsqueeze(0), + "attention_mask": attention_mask.unsqueeze(0), + } + + logits = ch.func.functional_call( + model, (weights, buffers), args=(), kwargs=kw_inputs + ) bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False) logits_correct = logits[bindex, label.unsqueeze(0)] cloned_logits = logits.clone() - cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(-ch.inf, device=logits.device, dtype=logits.dtype) + cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor( + -ch.inf, device=logits.device, dtype=logits.dtype + ) margins = logits_correct - cloned_logits.logsumexp(dim=-1) return margins.sum() - def get_out_to_loss_grad(self, model, weights, buffers, batch: Iterable[Tensor]) -> Tensor: + def get_out_to_loss_grad( + self, model, weights, buffers, batch: Iterable[Tensor] + ) -> Tensor: input_ids, token_type_ids, attention_mask, labels = batch - kw_inputs = {'input_ids': input_ids, - 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask} - logits = ch.func.functional_call(model, - (weights, buffers), - args=(), - kwargs=kw_inputs) - ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels] + kw_inputs = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": attention_mask, + } + logits = ch.func.functional_call( + model, (weights, buffers), args=(), kwargs=kw_inputs + ) + ps = self.softmax(logits / self.loss_temperature)[ + ch.arange(logits.size(0)), labels + ] return (1 - ps).clone().detach().unsqueeze(-1) TASK_TO_MODELOUT = { - 'image_classification': ImageClassificationModelOutput, - 'clip': CLIPModelOutput, - 'text_classification': TextClassificationModelOutput, + "image_classification": ImageClassificationModelOutput, + "clip": CLIPModelOutput, + "text_classification": TextClassificationModelOutput, } diff --git a/trak/projectors.py b/trak/projectors.py index 
d736bfa..a833da9 100644
--- a/trak/projectors.py
+++ b/trak/projectors.py
@@ -1,30 +1,48 @@
+"""
+Projectors are used to project gradients to a lower-dimensional space. This 1) allows
+us to compute TRAK scores in a *much* more efficient manner, and 2) turns out to
+act as a useful regularizer (see Appendix E.1 in our paper).
+
+Here, we provide four implementations of the projector:
+- :class:`NoOpProjector` (no-op)
+- :class:`BasicSingleBlockProjector` (bare-bones, inefficient implementation)
+- :class:`BasicProjector` (block-wise implementation)
+- :class:`CudaProjector` (a fast implementation with a custom CUDA kernel)
+"""
 from abc import ABC, abstractmethod
 from typing import Union
 from enum import Enum
-from torch import Tensor
 import math
+from torch import Tensor
 import torch
+
+from .utils import vectorize
+
+
 ch = torch


 class ProjectionType(str, Enum):
-    normal: str = 'normal'
-    rademacher: str = 'rademacher'
+    normal: str = "normal"
+    rademacher: str = "rademacher"


 class AbstractProjector(ABC):
-    """ Implementations of the Projector class must implement the
+    """Implementations of the Projector class must implement the
     :meth:`AbstractProjector.project` method, which takes in model gradients
     and returns the projected gradients.
     """
+
     @abstractmethod
-    def __init__(self,
-                 grad_dim: int,
-                 proj_dim: int,
-                 seed: int,
-                 proj_type: Union[str, ProjectionType],
-                 device: Union[str, torch.device]) -> None:
-        """ Initializes hyperparameters for the projection.
+    def __init__(
+        self,
+        grad_dim: int,
+        proj_dim: int,
+        seed: int,
+        proj_type: Union[str, ProjectionType],
+        device: Union[str, torch.device],
+    ) -> None:
+        """Initializes hyperparameters for the projection.

         Args:
             grad_dim (int):
@@ -53,7 +71,7 @@ def __init__(self,

     @abstractmethod
     def project(self, grads: Tensor, model_id: int) -> Tensor:
-        """ Performs the random projection. Model ID is included
+        """Performs the random projection. Model ID is included
         so that we generate different projection matrices for every
         model ID.

@@ -64,7 +82,9 @@ def project(self, grads: Tensor, model_id: int) -> Tensor:
         Returns:
             Tensor: the projected gradients
         """
-        ...
+
+    def free_memory(self):
+        """Frees up memory used by the projector."""


 class NoOpProjector(AbstractProjector):
@@ -72,18 +92,21 @@ class NoOpProjector(AbstractProjector):
     A projector that returns the gradients as they are, i.e., implements
     :code:`projector.project(grad) = grad`.
     """
-    def __init__(self,
-                 grad_dim: int = 0,
-                 proj_dim: int = 0,
-                 seed: int = 0,
-                 proj_type: Union[str, ProjectionType] = 'na',
-                 device: Union[str, torch.device] = 'na',
-                 *args,
-                 **kwargs) -> None:
+
+    def __init__(
+        self,
+        grad_dim: int = 0,
+        proj_dim: int = 0,
+        seed: int = 0,
+        proj_type: Union[str, ProjectionType] = "na",
+        device: Union[str, torch.device] = "cuda",
+        *args,
+        **kwargs,
+    ) -> None:
         super().__init__(grad_dim, proj_dim, seed, proj_type, device)

     def project(self, grads: Tensor, model_id: int) -> Tensor:
-        """ A no-op method.
+        """A no-op method.

         Args:
             grads (Tensor): a batch of gradients to be projected
@@ -92,7 +115,11 @@ def project(self, grads: Tensor, model_id: int) -> Tensor:
         Returns:
             Tensor: the (non-)projected gradients
         """
-        return grads
+        return vectorize(grads, device=self.device)
+
+    def free_memory(self):
+        """A no-op method."""
+        pass


 class BasicSingleBlockProjector(AbstractProjector):
@@ -107,40 +134,69 @@ class BasicSingleBlockProjector(AbstractProjector):
     added this only for testing purposes), use instead the CudaProjector or
     BasicProjector.
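
    A minimal usage sketch (illustrative only: the parameter name is made up,
    and :code:`grads` is assumed to be a dictionary of per-sample gradients
    keyed by parameter name, as produced by the gradient computer)::

        projector = BasicSingleBlockProjector(
            grad_dim=1_000,
            proj_dim=512,
            seed=0,
            proj_type=ProjectionType.rademacher,
            device="cpu",
        )
        grads = {"layer.weight": ch.randn(8, 1_000)}  # hypothetical gradients
        projected = projector.project(grads, model_id=0)  # shape (8, 512)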
""" - def __init__(self, grad_dim: int, proj_dim: int, seed: int, proj_type: - ProjectionType, device, dtype=ch.float32, model_id=0, - *args, **kwargs) -> None: + + def __init__( + self, + grad_dim: int, + proj_dim: int, + seed: int, + proj_type: ProjectionType, + device, + dtype=ch.float32, + model_id=0, + *args, + **kwargs, + ) -> None: super().__init__(grad_dim, proj_dim, seed, proj_type, device) self.model_id = model_id self.proj_type = proj_type self.generator = ch.Generator(device=self.device) - self.generator = self.generator.manual_seed(self.seed + int(1e4) * self.model_id) + self.generator = self.generator.manual_seed( + self.seed + int(1e4) * self.model_id + ) self.dtype = dtype - self.proj_matrix = ch.empty(self.grad_dim, - self.proj_dim, - dtype=self.dtype, - device=self.device) + self.proj_matrix = ch.empty( + self.grad_dim, self.proj_dim, dtype=self.dtype, device=self.device + ) + + self.proj_matrix_available = True self.generate_sketch_matrix() # updates self.proj_matrix + def free_memory(self): + del self.proj_matrix + self.proj_matrix_available = False + def generate_sketch_matrix(self): - if self.proj_type == ProjectionType.normal or self.proj_type == 'normal': + if not self.proj_matrix_available: + self.proj_matrix = ch.empty( + self.grad_dim, self.proj_dim, dtype=self.dtype, device=self.device + ) + self.proj_matrix_available = True + + if self.proj_type == ProjectionType.normal or self.proj_type == "normal": self.proj_matrix.normal_(generator=self.generator) - elif self.proj_type == ProjectionType.rademacher or self.proj_type == 'rademacher': + elif ( + self.proj_type == ProjectionType.rademacher + or self.proj_type == "rademacher" + ): self.proj_matrix.bernoulli_(p=0.5, generator=self.generator) # going from Bernoulli {0, 1} to Rademacher {-1, 1} - self.proj_matrix *= 2. - self.proj_matrix -= 1. + self.proj_matrix *= 2.0 + self.proj_matrix -= 1.0 else: - raise KeyError(f'Projection type {self.proj_type} not recognized.') + raise KeyError(f"Projection type {self.proj_type} not recognized.") def project(self, grads: Tensor, model_id: int) -> Tensor: + grads = vectorize(grads, device=self.device) grads = grads.to(dtype=self.dtype) if model_id != self.model_id: self.model_id = model_id - self.generator = self.generator.manual_seed(self.seed + int(1e4) * self.model_id) + self.generator = self.generator.manual_seed( + self.seed + int(1e4) * self.model_id + ) self.generate_sketch_matrix() # updates self.proj_matrix return grads @ self.proj_matrix @@ -158,15 +214,20 @@ class BasicProjector(AbstractProjector): a CUDA-enabled device with compute capability >=7.0 (see https://developer.nvidia.com/cuda-gpus). 
""" - def __init__(self, grad_dim: int, - proj_dim: int, - seed: int, - proj_type: ProjectionType, - device: torch.device, - block_size: int = 100, - dtype: torch.dtype = ch.float32, - model_id=0, - *args, **kwargs) -> None: + + def __init__( + self, + grad_dim: int, + proj_dim: int, + seed: int, + proj_type: ProjectionType, + device: torch.device, + block_size: int = 100, + dtype: torch.dtype = ch.float32, + model_id=0, + *args, + **kwargs, + ) -> None: super().__init__(grad_dim, proj_dim, seed, proj_type, device) self.block_size = min(self.proj_dim, block_size) @@ -175,16 +236,21 @@ def __init__(self, grad_dim: int, self.proj_type = proj_type self.model_id = model_id - self.proj_matrix = ch.empty(self.grad_dim, - self.block_size, - dtype=self.dtype, - device=self.device) + self.proj_matrix = ch.empty( + self.grad_dim, self.block_size, dtype=self.dtype, device=self.device + ) + + self.proj_matrix_available = True self.generator = ch.Generator(device=self.device) self.get_generator_states() self.generate_sketch_matrix(self.generator_states[0]) + def free_memory(self): + del self.proj_matrix + self.proj_matrix_available = False + def get_generator_states(self): self.generator_states = [] self.seeds = [] @@ -197,20 +263,31 @@ def get_generator_states(self): self.generator_states.append(self.generator.get_state()) def generate_sketch_matrix(self, generator_state): + if not self.proj_matrix_available: + self.proj_matrix = ch.empty( + self.grad_dim, self.block_size, dtype=self.dtype, device=self.device + ) + self.proj_matrix_available = True + self.generator.set_state(generator_state) - if self.proj_type == ProjectionType.normal or self.proj_type == 'normal': + if self.proj_type == ProjectionType.normal or self.proj_type == "normal": self.proj_matrix.normal_(generator=self.generator) - elif self.proj_type == ProjectionType.rademacher or self.proj_type == 'rademacher': + elif ( + self.proj_type == ProjectionType.rademacher + or self.proj_type == "rademacher" + ): self.proj_matrix.bernoulli_(p=0.5, generator=self.generator) - self.proj_matrix *= 2. - self.proj_matrix -= 1. + self.proj_matrix *= 2.0 + self.proj_matrix -= 1.0 else: - raise KeyError(f'Projection type {self.proj_type} not recognized.') + raise KeyError(f"Projection type {self.proj_type} not recognized.") def project(self, grads: Tensor, model_id: int) -> Tensor: + grads = vectorize(grads, device=self.device) grads = grads.to(dtype=self.dtype) - sketch = ch.zeros(size=(grads.size(0), self.proj_dim), - dtype=self.dtype, device=self.device) + sketch = ch.zeros( + size=(grads.size(0), self.proj_dim), dtype=self.dtype, device=self.device + ) if model_id != self.model_id: self.model_id = model_id @@ -226,7 +303,9 @@ def project(self, grads: Tensor, model_id: int) -> Tensor: st = ind * self.block_size ed = min((ind + 1) * self.block_size, self.proj_dim) - sketch[:, st:ed] = grads.type(self.dtype) @ self.proj_matrix[:, :(ed - st)] + sketch[:, st:ed] = ( + grads.type(self.dtype) @ self.proj_matrix[:, : (ed - st)] + ) return sketch.type(grads.dtype) @@ -235,8 +314,18 @@ class CudaProjector(AbstractProjector): A performant implementation of the projection for CUDA with compute capability >= 7.0. 
""" - def __init__(self, grad_dim: int, proj_dim: int, seed: int, proj_type: - ProjectionType, device, max_batch_size: int, *args, **kwargs) -> None: + + def __init__( + self, + grad_dim: int, + proj_dim: int, + seed: int, + proj_type: ProjectionType, + device, + max_batch_size: int, + *args, + **kwargs, + ) -> None: """ Args: @@ -269,7 +358,7 @@ def __init__(self, grad_dim: int, proj_dim: int, seed: int, proj_type: if isinstance(device, str): device = ch.device(device) - if device.type != 'cuda': + if device.type != "cuda": err = "CudaProjector only works on a CUDA device; Either switch to a CUDA device, or use the BasicProjector" raise ValueError(err) @@ -277,14 +366,24 @@ def __init__(self, grad_dim: int, proj_dim: int, seed: int, proj_type: try: import fast_jl + # test run to catch at init time if projection goes through - fast_jl.project_rademacher_8(ch.zeros(8, 1_000, device='cuda'), 512, 0, self.num_sms) + fast_jl.project_rademacher_8( + ch.zeros(8, 1_000, device="cuda"), 512, 0, self.num_sms + ) except ImportError: err = "You should make sure to install the CUDA projector for traker (called fast_jl).\ See the installation FAQs for more details." raise ModuleNotFoundError(err) - def project(self, grads: Tensor, model_id: int) -> Tensor: + def project( + self, + grads: Union[dict, Tensor], + model_id: int, + is_grads_dict: bool = True, + ) -> Tensor: + if is_grads_dict: + grads = vectorize(grads, device=self.device) batch_size = grads.shape[0] effective_batch_size = 32 @@ -297,15 +396,118 @@ def project(self, grads: Tensor, model_id: int) -> Tensor: function_name = f"project_{self.proj_type.value}_{effective_batch_size}" import fast_jl + fn = getattr(fast_jl, function_name) try: - result = fn(grads, self.proj_dim, self.seed + int(1e4) * model_id, self.num_sms) + result = fn( + grads, self.proj_dim, self.seed + int(1e4) * model_id, self.num_sms + ) except RuntimeError as e: - if str(e) == 'CUDA error: too many resources requested for launch\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n': # noqa: E501 + if "CUDA error: too many resources requested for launch" in str(e): # provide a more helpful error message - raise RuntimeError('The batch size of the CudaProjector is too large for your GPU. Reduce it by using the proj_max_batch_size argument of the TRAKer.\nOriginal error.') # noqa: E501 + raise RuntimeError( + ( + "The batch size of the CudaProjector is too large for your GPU. 
" + "Reduce it by using the proj_max_batch_size argument of the TRAKer.\nOriginal error:" + ) + ) else: raise e return result + + def free_memory(self): + """A no-op method.""" + pass + + +class ChunkedCudaProjector: + def __init__( + self, + projector_per_chunk: list, + max_chunk_size: int, + params_per_chunk: list, + feat_bs: int, + device: torch.device, + dtype: torch.dtype, + ): + self.projector_per_chunk = projector_per_chunk + self.proj_dim = self.projector_per_chunk[0].proj_dim + self.proj_type = self.projector_per_chunk[0].proj_type + self.params_per_chunk = params_per_chunk + + self.max_chunk_size = max_chunk_size + self.feat_bs = feat_bs + self.device = device + self.dtype = dtype + self.input_allocated = False + + def allocate_input(self): + if self.input_allocated: + return + + self.ch_input = ch.zeros( + size=(self.feat_bs, self.max_chunk_size), + device=self.device, + dtype=self.dtype, + ) + + self.input_allocated = True + + def free_memory(self): + if not self.input_allocated: + return + + del self.ch_input + self.input_allocated = False + + def project(self, grads, model_id): + self.allocate_input() + ch_output = ch.zeros( + size=(self.feat_bs, self.proj_dim), device=self.device, dtype=self.dtype + ) + pointer = 0 + # iterate over params, keep a counter of params so far, and when prev + # chunk reaches max_chunk_size, project and accumulate + projector_index = 0 + for i, p in enumerate(grads.values()): + if len(p.shape) < 2: + p_flat = p.data.unsqueeze(-1) + else: + p_flat = p.data.flatten(start_dim=1) + + param_size = p_flat.size(1) + if pointer + param_size > self.max_chunk_size: + # fill remaining entries with 0 + assert pointer == self.params_per_chunk[projector_index] + # project and accumulate + ch_output.add_( + self.projector_per_chunk[projector_index].project( + self.ch_input[:, :pointer].contiguous(), + model_id=model_id, + is_grads_dict=False, + ) + ) + # reset counter + pointer = 0 + projector_index += 1 + + # continue accumulation + actual_bs = min(self.ch_input.size(0), p_flat.size(0)) + self.ch_input[:actual_bs, pointer : pointer + param_size].copy_(p_flat) + pointer += param_size + + # at the end, we need to project remaining items + # fill remaining entries with 0 + assert pointer == self.params_per_chunk[projector_index] + # project and accumulate + ch_output[:actual_bs].add_( + self.projector_per_chunk[projector_index].project( + self.ch_input[:actual_bs, :pointer].contiguous(), + model_id=model_id, + is_grads_dict=False, + ) + ) + + return ch_output[:actual_bs] diff --git a/trak/savers.py b/trak/savers.py index 94ff052..dd96881 100644 --- a/trak/savers.py +++ b/trak/savers.py @@ -1,3 +1,11 @@ +""" +This module contains classes that save TRAK results, intermediate values, and +metadata to disk. The :code:`AbstractSaver` class defines the interface for +savers. Then, we provide one implementation: +- :class:`MmapSaver`: A saver that uses memory-mapped numpy arrays. This makes + loading and saving small chunks of data (e.g.) during featurizing feasible + without loading the entire file into memory. +""" from abc import ABC, abstractmethod from typing import Optional, Iterable, Union from pathlib import Path @@ -7,6 +15,7 @@ import numpy as np from numpy.lib.format import open_memmap import torch + ch = torch @@ -21,14 +30,17 @@ class AbstractSaver(ABC): used for the dimensionality reduction step of TRAK (Johnson-Lindenstrauss projection). 
""" + @abstractmethod - def __init__(self, - save_dir: Union[Path, str], - metadata: Iterable, - load_from_save_dir: bool, - logging_level: int, - use_half_precision: bool) -> None: - """ Creates the save directory if it doesn't already exist. + def __init__( + self, + save_dir: Union[Path, str], + metadata: Iterable, + load_from_save_dir: bool, + logging_level: int, + use_half_precision: bool, + ) -> None: + """Creates the save directory if it doesn't already exist. If the save directory already exists, it validates that the current TRAKer class has the same hyperparameters (metadata) as the one specified in the save directory. Next, this method loads any existing @@ -57,85 +69,95 @@ def __init__(self, self.use_half_precision = use_half_precision os.makedirs(self.save_dir, exist_ok=True) - os.makedirs(self.save_dir.joinpath('scores'), exist_ok=True) + os.makedirs(self.save_dir.joinpath("scores"), exist_ok=True) - self.logger = logging.getLogger('STORE') + self.logger = logging.getLogger("STORE") self.logger.setLevel(logging_level) # init TRAKer metadata - self.metadata_file = self.save_dir.joinpath('metadata.json') + self.metadata_file = self.save_dir.joinpath("metadata.json") if os.path.exists(self.metadata_file) and self.load_from_save_dir: - with open(self.metadata_file, 'r') as f: + with open(self.metadata_file, "r") as f: existsing_metadata = json.load(f) - existing_jl_dim = int(existsing_metadata['JL dimension']) - assert self.metadata['JL dimension'] == existing_jl_dim,\ - f"In {self.save_dir} there are models using JL dimension {existing_jl_dim},\n\ + existing_jl_dim = int(existsing_metadata["JL dimension"]) + assert ( + self.metadata["JL dimension"] == existing_jl_dim + ), f"In {self.save_dir} there are models using JL dimension {existing_jl_dim},\n\ and this TRAKer instance uses JL dimension {self.metadata['JL dimension']}." - existing_matrix_type = existsing_metadata['JL matrix type'] - assert self.metadata['JL matrix type'] == existing_matrix_type,\ - f"In {self.save_dir} there are models using a {existing_matrix_type} JL matrix,\n\ + existing_matrix_type = existsing_metadata["JL matrix type"] + assert ( + self.metadata["JL matrix type"] == existing_matrix_type + ), f"In {self.save_dir} there are models using a {existing_matrix_type} JL matrix,\n\ and this TRAKer instance uses a {self.metadata['JL matrix type']} JL matrix." - assert self.metadata['train set size'] == existsing_metadata['train set size'],\ - f"In {self.save_dir} there are models TRAKing\n\ + assert ( + self.metadata["train set size"] == existsing_metadata["train set size"] + ), f"In {self.save_dir} there are models TRAKing\n\ {existsing_metadata['train set size']} examples, and in this TRAKer instance\n\ there are {self.metadata['train set size']} examples." 
elif self.load_from_save_dir: - with open(self.metadata_file, 'w') as f: + with open(self.metadata_file, "w") as f: json.dump(self.metadata, f) self.model_ids = {} self.experiments = {} - self.experiments_file = self.save_dir.joinpath('experiments.json') + self.experiments_file = self.save_dir.joinpath("experiments.json") if self.load_from_save_dir: # check if there are existing model ids in the save_dir - self.model_ids_files = self.save_dir.rglob('id_*.json') + self.model_ids_files = self.save_dir.rglob("id_*.json") for existing_model_id_file in self.model_ids_files: - with open(existing_model_id_file, 'r') as f: + with open(existing_model_id_file, "r") as f: existing_id = json.load(f) - existing_id = {int(model_id): metadata - for model_id, metadata in existing_id.items()} + existing_id = { + int(model_id): metadata + for model_id, metadata in existing_id.items() + } self.model_ids.update(existing_id) if os.path.isfile(self.experiments_file): - with open(self.experiments_file, 'r') as f: + with open(self.experiments_file, "r") as f: self.experiments.update(json.load(f)) else: - with open(self.experiments_file, 'w') as f: + with open(self.experiments_file, "w") as f: json.dump({}, f) existing_ids = list(self.model_ids.keys()) if len(existing_ids) > 0: - self.logger.info(f'Existing model IDs in {self.save_dir}: {sorted(existing_ids)}') - ids_finalized = sorted(list([id for id, v in self.model_ids.items() - if v['is_finalized'] == 1])) + self.logger.info( + f"Existing model IDs in {self.save_dir}: {sorted(existing_ids)}" + ) + ids_finalized = sorted( + list([id for id, v in self.model_ids.items() if v["is_finalized"] == 1]) + ) if len(ids_finalized) > 0: - self.logger.info(f'Model IDs that have been finalized: {ids_finalized}') + self.logger.info(f"Model IDs that have been finalized: {ids_finalized}") else: - self.logger.info(f'No model IDs in {self.save_dir} have been finalized.') + self.logger.info( + f"No model IDs in {self.save_dir} have been finalized." + ) else: - self.logger.info(f'No existing model IDs in {self.save_dir}.') + self.logger.info(f"No existing model IDs in {self.save_dir}.") if len(list(self.experiments.keys())) > 0: - self.logger.info('Existing TRAK scores:') + self.logger.info("Existing TRAK scores:") for exp_name, values in self.experiments.items(): self.logger.info(f"{exp_name}: {values['scores_path']}") else: - self.logger.info(f'No existing TRAK scores in {self.save_dir}.') + self.logger.info(f"No existing TRAK scores in {self.save_dir}.") self.current_model_id = None self.current_store = { - 'grads': None, - 'out_to_loss': None, - 'features': None, + "grads": None, + "out_to_loss": None, + "features": None, } @abstractmethod def register_model_id(self, model_id: int) -> None: - """ Create metadata for a new model ID (checkpoint). + """Create metadata for a new model ID (checkpoint). Args: model_id (int): @@ -146,7 +168,7 @@ def register_model_id(self, model_id: int) -> None: @abstractmethod def serialize_current_model_id_metadata(self) -> None: - """ Write to disk / commit any updates to the metadata associated + """Write to disk / commit any updates to the metadata associated to the current model ID """ @@ -154,7 +176,7 @@ def serialize_current_model_id_metadata(self) -> None: @abstractmethod def init_store(self, model_id: int) -> None: - """ Initializes store for a given model ID (checkpoint). + """Initializes store for a given model ID (checkpoint). 
Args:
            model_id (int):
@@ -164,7 +186,7 @@ def init_experiment(self, model_id: int) -> None:
-        """ Initializes store for a given experiment & model ID (checkpoint).
+        """Initializes store for a given experiment & model ID (checkpoint).

         Args:
             model_id (int):
@@ -174,7 +196,7 @@ def init_experiment(self, model_id: int) -> None:

     @abstractmethod
     def load_current_store(self, model_id: int) -> None:
-        """ Populates the self.current_store attributes with data for the
+        """Populates the self.current_store attributes with data for the
         given model ID (checkpoint).

         Args:
@@ -186,7 +208,7 @@ def load_current_store(self, model_id: int) -> None:

     @abstractmethod
     def save_scores(self, exp_name: str) -> None:
-        """ Saves scores for a given experiment name
+        """Saves scores for a given experiment name

         Args:
             exp_name (str):
@@ -197,7 +219,7 @@ def save_scores(self, exp_name: str) -> None:

     @abstractmethod
     def del_grads(self, model_id: int, target: bool) -> None:
-        """ Delete the intermediate values (gradients) for a given model id
+        """Delete the intermediate values (gradients) for a given model id

         Args:
             model_id (int):
@@ -211,30 +233,42 @@ def del_grads(self, model_id: int, target: bool) -> None:


 class ModelIDException(Exception):
-    """ A minimal custom exception for errors related to model IDs """
+    """A minimal custom exception for errors related to model IDs"""
+
     pass


 class MmapSaver(AbstractSaver):
-    """ A saver that uses memory-mapped numpy arrays. This makes small reads and
+    """A saver that uses memory-mapped numpy arrays. This makes small reads and
     writes (e.g., during featurizing) feasible without loading the entire file
     into memory.
     """
-    def __init__(self, save_dir, metadata, train_set_size, proj_dim,
-                 load_from_save_dir, logging_level, use_half_precision) -> None:
-        super().__init__(save_dir=save_dir,
-                         metadata=metadata,
-                         load_from_save_dir=load_from_save_dir,
-                         logging_level=logging_level,
-                         use_half_precision=use_half_precision)
+
+    def __init__(
+        self,
+        save_dir,
+        metadata,
+        train_set_size,
+        proj_dim,
+        load_from_save_dir,
+        logging_level,
+        use_half_precision,
+    ) -> None:
+        super().__init__(
+            save_dir=save_dir,
+            metadata=metadata,
+            load_from_save_dir=load_from_save_dir,
+            logging_level=logging_level,
+            use_half_precision=use_half_precision,
+        )
         self.train_set_size = train_set_size
         self.proj_dim = proj_dim

-    def register_model_id(self,
-                          model_id: int,
-                          _allow_featurizing_already_registered: bool) -> None:
-        """ This method
+    def register_model_id(
+        self, model_id: int, _allow_featurizing_already_registered: bool
+    ) -> None:
+        """This method
        1) checks if the model ID already exists in the save dir
        2) if yes, it raises an error since model IDs must be unique
        3) if not, it creates a metadata file for it and initializes store mmaps

         Args:
             model_id (int):
@@ -250,86 +284,103 @@ def register_model_id(self,
         """
         self.current_model_id = model_id

-        if self.current_model_id in self.model_ids.keys() and (not _allow_featurizing_already_registered):
-            err_msg = f'model id {self.current_model_id} is already registered. Check {self.save_dir}'
+        if self.current_model_id in self.model_ids.keys() and (
+            not _allow_featurizing_already_registered
+        ):
+            err_msg = f"model id {self.current_model_id} is already registered. 
Check {self.save_dir}" raise ModelIDException(err_msg) - self.model_ids[self.current_model_id] = {'is_featurized': 0, 'is_finalized': 0} + self.model_ids[self.current_model_id] = {"is_featurized": 0, "is_finalized": 0} self.init_store(self.current_model_id) self.serialize_current_model_id_metadata(already_exists=False) def serialize_current_model_id_metadata(self, already_exists=True) -> None: - is_featurized = int(self.current_store['is_featurized'].sum() == self.train_set_size) + is_featurized = int( + self.current_store["is_featurized"].sum() == self.train_set_size + ) # update the metadata JSON file content = { - self.current_model_id: - { - 'is_featurized': is_featurized, - 'is_finalized': self.model_ids[self.current_model_id]['is_finalized'] - } + self.current_model_id: { + "is_featurized": is_featurized, + "is_finalized": self.model_ids[self.current_model_id]["is_finalized"], } + } # update the metadata dict within the class instance - self.model_ids[self.current_model_id]['is_featurized'] = is_featurized + self.model_ids[self.current_model_id]["is_featurized"] = is_featurized if (is_featurized == 1) or not already_exists: - with open(self.save_dir.joinpath(f'id_{self.current_model_id}.json'), 'w') as f: + with open( + self.save_dir.joinpath(f"id_{self.current_model_id}.json"), "w" + ) as f: json.dump(content, f) def init_store(self, model_id) -> None: prefix = self.save_dir.joinpath(str(model_id)) if os.path.exists(prefix): - self.logger.info(f'Model ID folder {prefix} already exists') + self.logger.info(f"Model ID folder {prefix} already exists") os.makedirs(prefix, exist_ok=True) featurized_so_far = np.zeros(shape=(self.train_set_size,), dtype=np.int32) - ft = self._load(prefix.joinpath('_is_featurized.mmap'), - shape=(self.train_set_size, ), - mode='w+', - dtype=np.int32) + ft = self._load( + prefix.joinpath("_is_featurized.mmap"), + shape=(self.train_set_size,), + mode="w+", + dtype=np.int32, + ) ft[:] = featurized_so_far[:] ft.flush() - self.load_current_store(model_id, mode='w+') + self.load_current_store(model_id, mode="w+") def init_experiment(self, exp_name, num_targets, model_id) -> None: prefix = self.save_dir.joinpath(str(model_id)) if not os.path.exists(prefix): - raise ModelIDException(f'model ID folder {prefix} does not exist,\n\ - cannot start scoring') - self.experiments[exp_name] = {'num_targets': num_targets, - 'scores_path': str(self.save_dir.joinpath(f'scores/{exp_name}.mmap')), - 'scores_finalized': 0, - } + raise ModelIDException( + f"model ID folder {prefix} does not exist,\n\ + cannot start scoring" + ) + self.experiments[exp_name] = { + "num_targets": num_targets, + "scores_path": str(self.save_dir.joinpath(f"scores/{exp_name}.mmap")), + "scores_finalized": 0, + } # update experiments.json - with open(self.experiments_file, 'r') as fp: + with open(self.experiments_file, "r") as fp: exp_f = json.load(fp) exp_f[exp_name] = self.experiments[exp_name] - with open(self.experiments_file, 'w') as fp: + with open(self.experiments_file, "w") as fp: json.dump(exp_f, fp) - if os.path.exists(prefix.joinpath(f'{exp_name}_grads.mmap')): - mode = 'r+' + if os.path.exists(prefix.joinpath(f"{exp_name}_grads.mmap")): + mode = "r+" else: - mode = 'w+' - self.load_current_store(model_id=model_id, exp_name=exp_name, - exp_num_targets=num_targets, mode=mode) + mode = "w+" + self.load_current_store( + model_id=model_id, exp_name=exp_name, exp_num_targets=num_targets, mode=mode + ) def _load(self, fname, shape, mode, dtype=None): - if mode == 'w+': - self.logger.debug(f'Creating 
{fname}.') + if mode == "w+": + self.logger.debug(f"Creating {fname}.") else: - self.logger.debug(f'Loading {fname}.') + self.logger.debug(f"Loading {fname}.") if dtype is None: dtype = np.float16 if self.use_half_precision else np.float32 - return open_memmap(filename=fname, mode=mode, shape=shape, dtype=dtype) - - def load_current_store(self, - model_id: int, - exp_name: Optional[str] = None, - exp_num_targets: Optional[int] = -1, - mode: Optional[str] = 'r+') -> None: - """ This method uses numpy memmaps for serializing the TRAK results and + try: + return open_memmap(filename=fname, mode=mode, shape=shape, dtype=dtype) + except OSError: + self.logger.info(f"{fname} does not exist, skipping.") + return None + + def load_current_store( + self, + model_id: int, + exp_name: Optional[str] = None, + exp_num_targets: Optional[int] = -1, + mode: Optional[str] = "r+", + ) -> None: + """This method uses numpy memmaps for serializing the TRAK results and intermediate values. Args: @@ -352,25 +403,39 @@ def load_current_store(self, if exp_name is None: to_load = { - 'grads': (prefix.joinpath('grads.mmap'), - (self.train_set_size, self.proj_dim), - None), - 'out_to_loss': (prefix.joinpath('out_to_loss.mmap'), - (self.train_set_size, 1), - None), - 'features': (prefix.joinpath('features.mmap'), - (self.train_set_size, self.proj_dim), - None), - 'is_featurized': (prefix.joinpath('_is_featurized.mmap'), - (self.train_set_size, 1), - np.int32), + "grads": ( + prefix.joinpath("grads.mmap"), + (self.train_set_size, self.proj_dim), + None, + ), + "out_to_loss": ( + prefix.joinpath("out_to_loss.mmap"), + (self.train_set_size, 1), + None, + ), + "features": ( + prefix.joinpath("features.mmap"), + (self.train_set_size, self.proj_dim), + None, + ), + "is_featurized": ( + prefix.joinpath("_is_featurized.mmap"), + (self.train_set_size, 1), + np.int32, + ), } else: to_load = { - f'{exp_name}_grads': (prefix.joinpath(f'{exp_name}_grads.mmap'), - (exp_num_targets, self.proj_dim), None), - f'{exp_name}_scores': (self.save_dir.joinpath(f'scores/{exp_name}.mmap'), - (self.train_set_size, exp_num_targets), None), + f"{exp_name}_grads": ( + prefix.joinpath(f"{exp_name}_grads.mmap"), + (exp_num_targets, self.proj_dim), + None, + ), + f"{exp_name}_scores": ( + self.save_dir.joinpath(f"scores/{exp_name}.mmap"), + (self.train_set_size, exp_num_targets), + None, + ), } for name, (path, shape, dtype) in to_load.items(): @@ -378,13 +443,15 @@ def load_current_store(self, def save_scores(self, exp_name): assert self.current_experiment_name == exp_name - prefix = self.save_dir.joinpath('scores') - self.logger.info(f'Saving scores in {prefix}/{exp_name}.mmap') - self.current_store[f'{exp_name}_scores'].flush() - self.experiments[exp_name]['scores_finalized'] = True + prefix = self.save_dir.joinpath("scores") + self.logger.info(f"Saving scores in {prefix}/{exp_name}.mmap") + self.current_store[f"{exp_name}_scores"].flush() + self.experiments[exp_name]["scores_finalized"] = 1 + with open(self.experiments_file, "w") as fp: + json.dump(self.experiments, fp) def del_grads(self, model_id): - grads_file = self.save_dir.joinpath(str(model_id)).joinpath('grads.mmap') + grads_file = self.save_dir.joinpath(str(model_id)).joinpath("grads.mmap") # delete grads memmap grads_file.unlink() diff --git a/trak/score_computers.py b/trak/score_computers.py index 56a51c3..120e163 100644 --- a/trak/score_computers.py +++ b/trak/score_computers.py @@ -1,6 +1,22 @@ +""" +Computing scores for the TRAK algorithm from pre-computed projected gradients 
+involves a number of matrix multiplications. This module contains classes that +perform these operations. The :code:`AbstractScoreComputer` class defines the +interface for score computers. Then, we provide two implementations: +- :class:`BasicSingleBlockScoreComputer`: A bare-bones implementation, mostly for + testing purposes. +- :class:`BasicScoreComputer`: A more sophisticated implementation that does + block-wise matrix multiplications to avoid OOM errors. + +""" from abc import ABC, abstractmethod +import logging from torch import Tensor -import torch as ch +import torch + +from .utils import get_matrix_mult + +ch = torch class AbstractScoreComputer(ABC): @@ -11,6 +27,7 @@ class AbstractScoreComputer(ABC): - :code:`get_x_xtx_inv` - :code:`get_scores` """ + @abstractmethod def __init__(self, dtype, device) -> None: self.dtype = dtype @@ -18,25 +35,68 @@ def __init__(self, dtype, device) -> None: @abstractmethod def get_xtx(self, grads: Tensor) -> Tensor: - ... + """Computes :math:`X^\top X`, where :math:`X` is the matrix of projected + gradients. Here, the shape of :math:`X` is :code:`(n, p)`, where + :math:`n` is the number of training examples and :math:`p` is the + dimension of the projection. + + + Args: + grads (Tensor): projected gradients of shape :code:`(n, p)`. + + Returns: + Tensor: :math:`X^\top X` of shape :code:`(p, p)`. + """ @abstractmethod def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor: - ... + """Computes :math:`X(X^\top X)^{-1}`, where :math:`X` is the matrix of + projected gradients. Here, the shape of :math:`X` is :code:`(n, p)`, + where :math:`n` is the number of training examples and :math:`p` is the + dimension of the projection. This function takes as input the + pre-computed :math:`X^\top X` matrix, which is computed by the + :code:`get_xtx` method. + + Args: + grads (Tensor): projected gradients :math:`X` of shape :code:`(n, p)`. + xtx (Tensor): :math:`X^\top X` of shape :code:`(p, p)`. + + Returns: + Tensor: :math:`X(X^\top X)^{-1}` of shape :code:`(n, p)`. + """ @abstractmethod - def get_scores(self, features: Tensor, target_grads: Tensor) -> Tensor: - ... + def get_scores( + self, features: Tensor, target_grads: Tensor, accumulator: Tensor + ) -> None: + """Computes the scores for a given set of features and target gradients. + In particular, this function takes in a matrix of features + :math:`\Phi=X(X^\top X)^{-1}`, computed by the :code:`get_x_xtx_inv` + method, and a matrix of target (projected) gradients :math:`X_{target}`. + Then, it computes the scores as :math:`\Phi X_{target}^\top`. The + resulting matrix has shape :code:`(n, m)`, where :math:`n` is the number + of training examples and :math:`m` is the number of target examples. + + The :code:`accumulator` argument is used to store the result of the + computation. This is useful when computing scores for multiple model + checkpoints, as it allows us to re-use the same memory for the score + matrix. + + Args: + features (Tensor): features :math:`\Phi` of shape :code:`(n, p)`. + target_grads (Tensor): + target projected gradients :math:`X_{target}` of shape + :code:`(m, p)`. + accumulator (Tensor): accumulator of shape :code:`(n, m)`. + """ class BasicSingleBlockScoreComputer(AbstractScoreComputer): - """ A bare-bones implementation of :code:`ScoreComputer` that will likely + """A bare-bones implementation of :code:`ScoreComputer` that will likely OOM for almost all applications. Here for testing purposes only. 
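
    In full, the three steps amount to three dense matrix products (a sketch
    of the computation, with :code:`n` train samples, :code:`m` targets, and
    projection dimension :code:`p`)::

        xtx = grads.T @ grads                                          # (p, p)
        features = grads @ ch.linalg.inv(xtx.float()).to(grads.dtype)  # (n, p)
        accumulator += (features @ target_grads.T).detach().cpu()      # (n, m)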
Unless you have a good reason not to, you should use
    :func:`BasicScoreComputer` instead.
    """

-    def __init__(self, dtype, device) -> None:
-        super().__init__(dtype, device)

     def get_xtx(self, grads: Tensor) -> Tensor:
         return grads.T @ grads
@@ -45,27 +105,44 @@ def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:
         # torch.linalg.inv does not support float16
         return grads @ ch.linalg.inv(xtx.float()).to(self.dtype)

-    def get_scores(self, features: Tensor, target_grads: Tensor) -> Tensor:
-        return features @ target_grads.T
+    def get_scores(
+        self, features: Tensor, target_grads: Tensor, accumulator: Tensor
+    ) -> None:
+        accumulator += (features @ target_grads.T).detach().cpu()


 class BasicScoreComputer(AbstractScoreComputer):
-    """ An implementation of :code:`ScoreComputer` that computes matmuls in a
+    """An implementation of :code:`ScoreComputer` that computes matmuls in a
     block-wise manner.
     """
-    def __init__(self, dtype, device, CUDA_MAX_DIM_SIZE: int = 100_000) -> None:
+
+    def __init__(
+        self,
+        dtype: torch.dtype,
+        device: torch.device,
+        CUDA_MAX_DIM_SIZE: int = 20_000,
+        logging_level=logging.INFO,
+    ) -> None:
        """
        Args:
-            device (Union[str, torch.device]): torch device to do matmuls on
-            CUDA_MAX_DIM_SIZE (int, optional): Size of block for block-wise
-            matmuls. Defaults to 100_000.
+            dtype (torch.dtype):
+                torch dtype to use for the matmuls
+            device (Union[str, torch.device]):
+                torch device to do matmuls on
+            CUDA_MAX_DIM_SIZE (int, optional):
+                Size of block for block-wise matmuls. Defaults to 20_000.
+            logging_level (logging level, optional):
+                Logging level for the logger. Defaults to logging.INFO.
        """
        super().__init__(dtype, device)
        self.CUDA_MAX_DIM_SIZE = CUDA_MAX_DIM_SIZE
+        self.logger = logging.getLogger("ScoreComputer")
+        self.logger.setLevel(logging_level)

     def get_xtx(self, grads: Tensor) -> Tensor:
         self.proj_dim = grads.shape[1]
-        result = ch.zeros(self.proj_dim, self.proj_dim, dtype=self.dtype, device=self.device)
+        result = ch.zeros(
+            self.proj_dim, self.proj_dim, dtype=self.dtype, device=self.device
+        )
         blocks = ch.split(grads, split_size_or_sections=self.CUDA_MAX_DIM_SIZE, dim=0)

         for block in blocks:
@@ -82,26 +159,23 @@ def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:

         xtx_inv = xtx_inv.to(self.dtype)

-        result = ch.empty(grads.shape[0], xtx_inv.shape[1], dtype=self.dtype, device=self.device)
+        result = ch.empty(
+            grads.shape[0], xtx_inv.shape[1], dtype=self.dtype, device=self.device
+        )
         for i, block in enumerate(blocks):
             start = i * self.CUDA_MAX_DIM_SIZE
             end = min(grads.shape[0], (i + 1) * self.CUDA_MAX_DIM_SIZE)
-            result[start: end] = (block.to(self.device) @ xtx_inv)
+            result[start:end] = block.to(self.device) @ xtx_inv
         return result

-    def get_scores(self, features: Tensor, target_grads: Tensor) -> Tensor:
+    def get_scores(
+        self, features: Tensor, target_grads: Tensor, accumulator: Tensor
+    ) -> Tensor:
         train_dim = features.shape[0]
         target_dim = target_grads.shape[0]

-        if target_dim < self.CUDA_MAX_DIM_SIZE:
-            return features @ target_grads.T
+        self.logger.debug(f"{train_dim=}, {target_dim=}")

-        result = ch.empty(train_dim, target_dim, dtype=self.dtype, device=self.device)
-        blocks = ch.split(target_grads, split_size_or_sections=self.CUDA_MAX_DIM_SIZE, dim=0)
-
-        for i, block in enumerate(blocks):
-            start = i * self.CUDA_MAX_DIM_SIZE
-            end = min(target_grads.shape[0], (i + 1) * self.CUDA_MAX_DIM_SIZE)
-            result[:, start: end] = features @ block.T
-
-        return result
+        accumulator += (
+            get_matrix_mult(features=features, target_grads=target_grads).detach().cpu()
+        )
diff 
--git a/trak/traker.py b/trak/traker.py index 0ec22f4..58f98f1 100644 --- a/trak/traker.py +++ b/trak/traker.py @@ -1,10 +1,31 @@ +""" +This module contains the main :class:`.TRAKer` class, which is the front-facing +class for TRAK. See the `README `_ and `docs +`_ for example usage. + +In short, methods of the :class:`.TRAKer` class are used to compute TRAK scores +for a set of model checkpoints, a set of target samples, and a set of train samples. +This is done in two stages: +- Featurizing. :func:`.TRAKer.featurize` and :func:`.TRAKer.finalize_features` + are used to compute the TRAK features for a set of model checkpoints and a set + of train samples. +- Scoring. :func:`.TRAKer.start_scoring_checkpoint`, :func:`.TRAKer.score`, and + :func:`.TRAKer.finalize_scores` are used to compute the TRAK scores for a set + of target samples, given the TRAK features computed in the previous step. + +""" from .modelout_functions import AbstractModelOutput, TASK_TO_MODELOUT -from .projectors import ProjectionType, AbstractProjector, CudaProjector, BasicProjector -from .gradient_computers import FunctionalGradientComputer, \ - AbstractGradientComputer +from .projectors import ( + ProjectionType, + AbstractProjector, + CudaProjector, + BasicProjector, + ChunkedCudaProjector, +) +from .gradient_computers import FunctionalGradientComputer, AbstractGradientComputer from .score_computers import AbstractScoreComputer, BasicScoreComputer from .savers import AbstractSaver, MmapSaver, ModelIDException -from .utils import get_num_params +from .utils import get_num_params, get_parameter_chunk_sizes from typing import Iterable, Optional, Union from pathlib import Path @@ -14,32 +35,36 @@ import logging import numpy as np import torch + ch = torch -class TRAKer(): - """ The main front-facing class for TRAK. See the `README +class TRAKer: + """The main front-facing class for TRAK. See the `README `_ and `docs `_ for example usage. """ - def __init__(self, - model: torch.nn.Module, - task: Union[AbstractModelOutput, str], - train_set_size: int, - save_dir: str = './trak_results', - load_from_save_dir: bool = True, - device: Union[str, torch.device] = 'cuda', - gradient_computer: AbstractGradientComputer = FunctionalGradientComputer, - projector: Optional[AbstractProjector] = None, - saver: Optional[AbstractSaver] = None, - score_computer: Optional[AbstractScoreComputer] = None, - proj_dim: int = 2048, - logging_level=logging.INFO, - use_half_precision: bool = True, - proj_max_batch_size: int = 32, - projector_seed: int = 0, - ) -> None: + + def __init__( + self, + model: torch.nn.Module, + task: Union[AbstractModelOutput, str], + train_set_size: int, + save_dir: str = "./trak_results", + load_from_save_dir: bool = True, + device: Union[str, torch.device] = "cuda", + gradient_computer: AbstractGradientComputer = FunctionalGradientComputer, + projector: Optional[AbstractProjector] = None, + saver: Optional[AbstractSaver] = None, + score_computer: Optional[AbstractScoreComputer] = None, + proj_dim: int = 2048, + logging_level=logging.INFO, + use_half_precision: bool = True, + proj_max_batch_size: int = 32, + projector_seed: int = 0, + grad_wrt: Optional[Iterable[str]] = None, + ) -> None: """ Args: @@ -92,11 +117,18 @@ def __init__(self, computations and arrays will be stored in float16. Otherwise, it will use float32. Defaults to True. proj_max_batch_size (int): - Batch size used by fast_jl if teh CudaProjector is used. Must be + Batch size used by fast_jl if the CudaProjector is used. Must be a multiple of 8. 
The maximum batch size is 32 for A100 GPUs, 16 for V100 GPUs, 40 for H100 GPUs. Defaults to 32. - projecotr_seed (int): + projector_seed (int): Random seed used by the projector. Defaults to 0. + grad_wrt (Optional[Iterable[str]], optional): + If not None, the gradients will be computed only with respect to + the parameters specified in this list. The list should contain + the names of the parameters to compute gradients with respect to, + as they appear in the model's state dictionary. If None, + gradients are taken with respect to all model parameters. + Defaults to None. """ @@ -105,23 +137,33 @@ def __init__(self, self.train_set_size = train_set_size self.device = device self.dtype = ch.float16 if use_half_precision else ch.float32 + self.grad_wrt = grad_wrt logging.basicConfig() - self.logger = logging.getLogger('TRAK') + self.logger = logging.getLogger("TRAK") self.logger.setLevel(logging_level) - self.logger.warning('TRAK is still in an early 0.x.x version.\n\ - Report any issues at https://github.com/MadryLab/trak/issues') self.num_params = get_num_params(self.model) + if self.grad_wrt is not None: + d = dict(self.model.named_parameters()) + self.num_params_for_grad = sum( + [d[param_name].numel() for param_name in self.grad_wrt] + ) + else: + self.num_params_for_grad = self.num_params # inits self.projector self.proj_seed = projector_seed - self.init_projector(projector=projector, - proj_dim=proj_dim, - proj_max_batch_size=proj_max_batch_size) + self.init_projector( + projector=projector, + proj_dim=proj_dim, + proj_max_batch_size=proj_max_batch_size, + ) # normalize to make X^TX numerically stable # doing this instead of normalizing the projector matrix - self.normalize_factor = ch.sqrt(ch.tensor(self.num_params, dtype=ch.float32)) + self.normalize_factor = ch.sqrt( + ch.tensor(self.num_params_for_grad, dtype=ch.float32) + ) self.save_dir = Path(save_dir).resolve() self.load_from_save_dir = load_from_save_dir @@ -129,89 +171,169 @@ def __init__(self, if type(self.task) is str: self.task = TASK_TO_MODELOUT[self.task]() - self.gradient_computer = gradient_computer(model=self.model, - task=self.task, - grad_dim=self.num_params, - dtype=self.dtype, - device=self.device) + self.gradient_computer = gradient_computer( + model=self.model, + task=self.task, + grad_dim=self.num_params_for_grad, + dtype=self.dtype, + device=self.device, + grad_wrt=self.grad_wrt, + ) if score_computer is None: score_computer = BasicScoreComputer - self.score_computer = score_computer(dtype=self.dtype, - device=self.device) + self.score_computer = score_computer( + dtype=self.dtype, device=self.device, logging_level=logging_level + ) metadata = { - 'JL dimension': self.proj_dim, - 'JL matrix type': self.projector.proj_type, - 'train set size': self.train_set_size, + "JL dimension": self.proj_dim, + "JL matrix type": self.projector.proj_type, + "train set size": self.train_set_size, } if saver is None: saver = MmapSaver - self.saver = saver(save_dir=self.save_dir, - metadata=metadata, - train_set_size=self.train_set_size, - proj_dim=self.proj_dim, - load_from_save_dir=self.load_from_save_dir, - logging_level=logging_level, - use_half_precision=use_half_precision) - - self.ckpt_loaded = 'no ckpt loaded' - - def init_projector(self, projector, proj_dim, proj_max_batch_size) -> None: - """ Initialize the projector for a traker class + self.saver = saver( + save_dir=self.save_dir, + metadata=metadata, + train_set_size=self.train_set_size, + proj_dim=self.proj_dim, + load_from_save_dir=self.load_from_save_dir, + 
logging_level=logging_level,
+            use_half_precision=use_half_precision,
+        )
+
+        self.ckpt_loaded = "no ckpt loaded"
+
+    def init_projector(
+        self,
+        projector: Optional[AbstractProjector],
+        proj_dim: int,
+        proj_max_batch_size: int,
+    ) -> None:
+        """Initialize the projector for a traker class

         Args:
-            projector (AbstractProjector):
-                JL projector
-
+            projector (Optional[AbstractProjector]):
+                JL projector to use. If None, a CudaProjector will be used (if
+                possible).
+            proj_dim (int):
+                Dimension of the projected gradients and TRAK features.
+            proj_max_batch_size (int):
+                Batch size used by fast_jl if the CudaProjector is used. Must be
+                a multiple of 8. The maximum batch size is 32 for A100 GPUs, 16
+                for V100 GPUs, 40 for H100 GPUs.
         """
         self.projector = projector
         if projector is not None:
             self.proj_dim = self.projector.proj_dim
             if self.proj_dim == 0:  # using NoOpProjector
-                self.proj_dim = self.num_params
+                self.proj_dim = self.num_params_for_grad

         else:
+            using_cuda_projector = False
             self.proj_dim = proj_dim
-            if self.device == 'cpu':
-                self.logger.info('Using BasicProjector since device is CPU')
+            if self.device == "cpu":
+                self.logger.info("Using BasicProjector since device is CPU")
                 projector = BasicProjector
                 # Sampling from bernoulli distribution is not supported for
                 # dtype float16 on CPU; playing it safe here by defaulting to
                 # normal projection, rather than rademacher
                 proj_type = ProjectionType.normal
-                self.logger.info('Using Normal projection')
+                self.logger.info("Using Normal projection")
             else:
                 try:
                     import fast_jl
-                    test_gradient = ch.ones(1, self.num_params).cuda()
-                    num_sms = ch.cuda.get_device_properties('cuda').multi_processor_count
-                    fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms)
+
+                    test_gradient = ch.ones(1, self.num_params_for_grad).cuda()
+                    num_sms = ch.cuda.get_device_properties(
+                        "cuda"
+                    ).multi_processor_count
+                    fast_jl.project_rademacher_8(
+                        test_gradient, self.proj_dim, 0, num_sms
+                    )
                     projector = CudaProjector
+                    using_cuda_projector = True
                 except (ImportError, RuntimeError, AttributeError) as e:
-                    self.logger.error(f'Could not use CudaProjector.\nReason: {str(e)}')
-                    self.logger.error('Defaulting to BasicProjector.')
+                    self.logger.error(f"Could not use CudaProjector.\nReason: {str(e)}")
+                    self.logger.error("Defaulting to BasicProjector.")
                     projector = BasicProjector
                 proj_type = ProjectionType.rademacher

-        self.logger.debug(f'Initializing projector with grad_dim {self.num_params}')
-        self.projector = projector(grad_dim=self.num_params,
-                                   proj_dim=self.proj_dim,
-                                   seed=self.proj_seed,
-                                   proj_type=proj_type,
-                                   max_batch_size=proj_max_batch_size,
-                                   dtype=self.dtype,
-                                   device=self.device)
-        self.logger.debug(f'Initialized projector with proj_dim {self.proj_dim}')
-
-    def load_checkpoint(self,
-                        checkpoint: Iterable[Tensor],
-                        model_id: int,
-                        _allow_featurizing_already_registered=False) -> None:
-        """ Loads state dictionary for the given checkpoint; initializes arrays
+            if using_cuda_projector:
+                max_chunk_size, param_chunk_sizes = get_parameter_chunk_sizes(
+                    self.model, proj_max_batch_size
+                )
+                self.logger.debug(
+                    (
+                        f"the max chunk size is {max_chunk_size}, "
+                        "while the model has the following chunk sizes "
+                        f"{param_chunk_sizes}."
+                    )
+                )
+
+                if (
+                    len(param_chunk_sizes) > 1
+                ):  # we have to use the ChunkedCudaProjector
+                    self.logger.info(
+                        (
+                            f"Using ChunkedCudaProjector with "
+                            f"{len(param_chunk_sizes)} chunks of sizes "
+                            f"{param_chunk_sizes}."
+ ) + ) + rng = np.random.default_rng(self.proj_seed) + seeds = rng.integers( + low=0, + high=500, + size=len(param_chunk_sizes), + ) + projector_per_chunk = [ + projector( + grad_dim=chunk_size, + proj_dim=self.proj_dim, + seed=seeds[i], + proj_type=ProjectionType.rademacher, + max_batch_size=proj_max_batch_size, + dtype=self.dtype, + device=self.device, + ) + for i, chunk_size in enumerate(param_chunk_sizes) + ] + self.projector = ChunkedCudaProjector( + projector_per_chunk, + max_chunk_size, + param_chunk_sizes, + proj_max_batch_size, + self.device, + self.dtype, + ) + return # do not initialize projector below + + self.logger.debug( + f"Initializing projector with grad_dim {self.num_params_for_grad}" + ) + self.projector = projector( + grad_dim=self.num_params_for_grad, + proj_dim=self.proj_dim, + seed=self.proj_seed, + proj_type=proj_type, + max_batch_size=proj_max_batch_size, + dtype=self.dtype, + device=self.device, + ) + self.logger.debug(f"Initialized projector with proj_dim {self.proj_dim}") + + def load_checkpoint( + self, + checkpoint: Iterable[Tensor], + model_id: int, + _allow_featurizing_already_registered=False, + ) -> None: + """Loads state dictionary for the given checkpoint; initializes arrays to store TRAK features for that checkpoint, tied to the model ID. Args: @@ -226,8 +348,9 @@ def load_checkpoint(self, """ if self.saver.model_ids.get(model_id) is None: - self.saver.register_model_id(model_id, - _allow_featurizing_already_registered) + self.saver.register_model_id( + model_id, _allow_featurizing_already_registered + ) else: self.saver.load_current_store(model_id) @@ -239,12 +362,13 @@ def load_checkpoint(self, self._last_ind = 0 self.ckpt_loaded = model_id - def featurize(self, - batch: Iterable[Tensor], - inds: Optional[Iterable[int]] = None, - num_samples: Optional[int] = None - ) -> None: - """ Creates TRAK features for the given batch by computing the gradient + def featurize( + self, + batch: Iterable[Tensor], + inds: Optional[Iterable[int]] = None, + num_samples: Optional[int] = None, + ) -> None: + """Creates TRAK features for the given batch by computing the gradient of the model output function and projecting it. In the notation of the paper, for an input pair :math:`z=(x,y)`, model parameters :math:`\\theta`, and JL projection matrix :math:`P`, this method @@ -266,12 +390,15 @@ def featurize(self, Number of samples in the batch. Defaults to None. 
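
        A typical featurizing loop looks like this (a sketch; the loader name
        is illustrative, and batches are assumed to be tuples of device-placed
        tensors whose first element is the model input)::

            traker.load_checkpoint(checkpoint, model_id=0)
            for batch in loader_train:
                traker.featurize(batch=batch, num_samples=batch[0].shape[0])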
""" - assert self.ckpt_loaded == self.saver.current_model_id, \ - "Load a checkpoint using traker.load_checkpoint before featurizing" - assert (inds is None) or (num_samples is None), \ - "Exactly one of num_samples and inds should be specified" - assert (inds is not None) or (num_samples is not None), \ - "Exactly one of num_samples and inds should be specified" + assert ( + self.ckpt_loaded == self.saver.current_model_id + ), "Load a checkpoint using traker.load_checkpoint before featurizing" + assert (inds is None) or ( + num_samples is None + ), "Exactly one of num_samples and inds should be specified" + assert (inds is not None) or ( + num_samples is not None + ), "Exactly one of num_samples and inds should be specified" if num_samples is not None: inds = np.arange(self._last_ind, self._last_ind + num_samples) @@ -280,27 +407,33 @@ def featurize(self, num_samples = inds.reshape(-1).shape[0] # handle re-starting featurizing from a partially featurized model (some inds already featurized) - _already_done = (self.saver.current_store['is_featurized'][inds] == 1).reshape(-1) + _already_done = (self.saver.current_store["is_featurized"][inds] == 1).reshape( + -1 + ) inds = inds[~_already_done] if len(inds) == 0: - self.logger.debug('All samples in batch already featurized.') + self.logger.debug("All samples in batch already featurized.") return 0 grads = self.gradient_computer.compute_per_sample_grad(batch=batch) grads = self.projector.project(grads, model_id=self.saver.current_model_id) grads /= self.normalize_factor - self.saver.current_store['grads'][inds] = grads.to(self.dtype).cpu().clone().detach() + self.saver.current_store["grads"][inds] = ( + grads.to(self.dtype).cpu().clone().detach() + ) loss_grads = self.gradient_computer.compute_loss_grad(batch) - self.saver.current_store['out_to_loss'][inds] = loss_grads.to(self.dtype).cpu().clone().detach() + self.saver.current_store["out_to_loss"][inds] = ( + loss_grads.to(self.dtype).cpu().clone().detach() + ) - self.saver.current_store['is_featurized'][inds] = 1 + self.saver.current_store["is_featurized"][inds] = 1 self.saver.serialize_current_model_id_metadata() - def finalize_features(self, - model_ids: Iterable[int] = None, - del_grads: bool = False) -> None: - """ For a set of checkpoints :math:`C` (specified by model IDs), and + def finalize_features( + self, model_ids: Iterable[int] = None, del_grads: bool = False + ) -> None: + """For a set of checkpoints :math:`C` (specified by model IDs), and gradients :math:`\\{ \\Phi_c \\}_{c\\in C}`, this method computes :math:`\\Phi_c (\\Phi_c^\\top\\Phi_c)^{-1}` for all :math:`c\\in C` and stores the results in the internal store of the :func:`TRAKer` @@ -313,42 +446,54 @@ def finalize_features(self, class. Defaults to None. 
""" + + # this method is memory-intensive, so we're freeing memory beforehand + torch.cuda.empty_cache() + self.projector.free_memory() + if model_ids is None: model_ids = list(self.saver.model_ids.keys()) self._last_ind = 0 - for model_id in tqdm(model_ids, desc='Finalizing features for all model IDs..'): + for model_id in tqdm(model_ids, desc="Finalizing features for all model IDs.."): if self.saver.model_ids.get(model_id) is None: - raise ModelIDException(f'Model ID {model_id} not registered, not ready for finalizing.') - elif self.saver.model_ids[model_id]['is_featurized'] == 0: - raise ModelIDException(f'Model ID {model_id} not fully featurized, not ready for finalizing.') - elif self.saver.model_ids[model_id]['is_finalized'] == 1: - self.logger.warning(f'Model ID {model_id} already finalized, skipping .finalize_features for it.') + raise ModelIDException( + f"Model ID {model_id} not registered, not ready for finalizing." + ) + elif self.saver.model_ids[model_id]["is_featurized"] == 0: + raise ModelIDException( + f"Model ID {model_id} not fully featurized, not ready for finalizing." + ) + elif self.saver.model_ids[model_id]["is_finalized"] == 1: + self.logger.warning( + f"Model ID {model_id} already finalized, skipping .finalize_features for it." + ) continue self.saver.load_current_store(model_id) - g = ch.as_tensor(self.saver.current_store['grads']) + g = ch.as_tensor(self.saver.current_store["grads"]) xtx = self.score_computer.get_xtx(g) - self.logger.debug(f'XTX is {xtx}') + self.logger.debug(f"XTX is {xtx}") features = self.score_computer.get_x_xtx_inv(g, xtx) - self.logger.debug(f'Features are {features}') - self.saver.current_store['features'][:] = features.to(self.dtype).cpu() + self.logger.debug(f"Features are {features}") + self.saver.current_store["features"][:] = features.to(self.dtype).cpu() if del_grads: self.saver.del_grads(model_id) - self.saver.model_ids[self.saver.current_model_id]['is_finalized'] = 1 + self.saver.model_ids[self.saver.current_model_id]["is_finalized"] = 1 self.saver.serialize_current_model_id_metadata() - def start_scoring_checkpoint(self, - exp_name: str, - checkpoint: Iterable[Tensor], - model_id: int, - num_targets: int, - ) -> None: - """ This method prepares the internal store of the :class:`.TRAKer` class + def start_scoring_checkpoint( + self, + exp_name: str, + checkpoint: Iterable[Tensor], + model_id: int, + num_targets: int, + ) -> None: + """This method prepares the internal store of the :class:`.TRAKer` class to start computing scores for a set of targets. Args: @@ -376,12 +521,13 @@ def start_scoring_checkpoint(self, # e.g. make it a value in self.saver.experiments[exp_name] self._last_ind_target = 0 - def score(self, - batch: Iterable[Tensor], - inds: Optional[Iterable[int]] = None, - num_samples: Optional[int] = None, - ) -> None: - """ This method computes the (intermediate per-checkpoint) TRAK scores + def score( + self, + batch: Iterable[Tensor], + inds: Optional[Iterable[int]] = None, + num_samples: Optional[int] = None, + ) -> None: + """This method computes the (intermediate per-checkpoint) TRAK scores for a batch of targets and stores them in the internal store of the :class:`.TRAKer` class. @@ -398,13 +544,17 @@ def score(self, Number of samples in the batch. Defaults to None. 
""" - assert (inds is None) or (num_samples is None), \ - "Exactly one of num_samples and inds should be specified" - assert (inds is not None) or (num_samples is not None), \ - "Exactly one of num_samples and inds should be specified" - - if self.saver.model_ids[self.saver.current_model_id]['is_finalized'] == 0: - self.logger.error(f'Model ID {self.saver.current_model_id} not finalized, cannot score') + assert (inds is None) or ( + num_samples is None + ), "Exactly one of num_samples and inds should be specified" + assert (inds is not None) or ( + num_samples is not None + ), "Exactly one of num_samples and inds should be specified" + + if self.saver.model_ids[self.saver.current_model_id]["is_finalized"] == 0: + self.logger.error( + f"Model ID {self.saver.current_model_id} not finalized, cannot score" + ) return None if num_samples is not None: @@ -419,14 +569,17 @@ def score(self, grads /= self.normalize_factor exp_name = self.saver.current_experiment_name - self.saver.current_store[f'{exp_name}_grads'][inds] = grads.to(self.dtype).cpu().clone().detach() - - def finalize_scores(self, - exp_name: str, - model_ids: Iterable[int] = None, - allow_skip: bool = False, - ) -> Tensor: - """ This method computes the final TRAK scores for the given targets, + self.saver.current_store[f"{exp_name}_grads"][inds] = ( + grads.to(self.dtype).cpu().clone().detach() + ) + + def finalize_scores( + self, + exp_name: str, + model_ids: Iterable[int] = None, + allow_skip: bool = False, + ) -> Tensor: + """This method computes the final TRAK scores for the given targets, train samples, and model checkpoints (specified by model IDs). Args: @@ -456,52 +609,72 @@ def finalize_scores(self, if model_ids is None: model_ids = self.saver.model_ids else: - model_ids = {model_id: self.saver.model_ids[model_id] for model_id in model_ids} - assert len(model_ids) > 0, 'No model IDs to finalize scores for' + model_ids = { + model_id: self.saver.model_ids[model_id] for model_id in model_ids + } + assert len(model_ids) > 0, "No model IDs to finalize scores for" if self.saver.experiments.get(exp_name) is None: - raise ValueError(f'Experiment {exp_name} does not exist. Create it\n\ - and compute scores first before finalizing.') + raise ValueError( + f"Experiment {exp_name} does not exist. Create it\n\ + and compute scores first before finalizing." + ) - num_targets = self.saver.experiments[exp_name]['num_targets'] + num_targets = self.saver.experiments[exp_name]["num_targets"] _completed = [False] * len(model_ids) self.saver.load_current_store(list(model_ids.keys())[0], exp_name, num_targets) - _scores = self.saver.current_store[f'{exp_name}_scores'] - _scores[:] = 0. - - _avg_out_to_losses = np.zeros((self.saver.train_set_size, 1), - dtype=np.float16 if self.dtype == ch.float16 else np.float32) - - for j, model_id in enumerate(tqdm(model_ids, desc='Finalizing scores for all model IDs..')): + _scores_mmap = self.saver.current_store[f"{exp_name}_scores"] + _scores_on_cpu = ch.zeros(*_scores_mmap.shape, device="cpu") + if self.device != "cpu": + _scores_on_cpu.pin_memory() + + _avg_out_to_losses = np.zeros( + (self.saver.train_set_size, 1), + dtype=np.float16 if self.dtype == ch.float16 else np.float32, + ) + + for j, model_id in enumerate( + tqdm(model_ids, desc="Finalizing scores for all model IDs..") + ): self.saver.load_current_store(model_id) try: self.saver.load_current_store(model_id, exp_name, num_targets) except OSError as e: if allow_skip: - self.logger.warning(f'Could not read target gradients for model ID {model_id}. 
Skipping.') + self.logger.warning( + f"Could not read target gradients for model ID {model_id}. Skipping." + ) continue else: raise e - if self.saver.model_ids[self.saver.current_model_id]['is_finalized'] == 0: - self.logger.warning(f'Model ID {self.saver.current_model_id} not finalized, cannot score') + if self.saver.model_ids[self.saver.current_model_id]["is_finalized"] == 0: + self.logger.warning( + f"Model ID {self.saver.current_model_id} not finalized, cannot score" + ) continue - g = ch.as_tensor(self.saver.current_store['features'], device=self.device) - g_target = ch.as_tensor(self.saver.current_store[f'{exp_name}_grads'], - device=self.device) + g = ch.as_tensor(self.saver.current_store["features"], device=self.device) + g_target = ch.as_tensor( + self.saver.current_store[f"{exp_name}_grads"], device=self.device + ) - _scores[:] += self.score_computer.get_scores(g, g_target).cpu().clone().detach().numpy() + self.score_computer.get_scores(g, g_target, accumulator=_scores_on_cpu) - _avg_out_to_losses += self.saver.current_store['out_to_loss'] + _avg_out_to_losses += self.saver.current_store["out_to_loss"] _completed[j] = True _num_models_used = float(sum(_completed)) - _scores[:] = (_scores / _num_models_used) * (_avg_out_to_losses / _num_models_used) - self.logger.debug(f'Scores dtype is {_scores.dtype}') + # only write to mmap (on disk) once at the end + _scores_mmap[:] = (_scores_on_cpu.numpy() / _num_models_used) * ( + _avg_out_to_losses / _num_models_used + ) + + self.logger.debug(f"Scores dtype is {_scores_mmap.dtype}") self.saver.save_scores(exp_name) - self.scores = _scores + self.scores = _scores_mmap return self.scores
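A note on the accumulation above: :code:`Tensor.pin_memory()` is not
in-place; it returns a pinned copy, which is why the result must be
re-assigned before use. A minimal illustration (requires a CUDA-enabled
build of PyTorch):

.. code-block:: python

    import torch as ch

    buf = ch.zeros(8, 8)
    buf.pin_memory()        # pitfall: the pinned copy is discarded
    assert not buf.is_pinned()
    buf = buf.pin_memory()  # correct: keep the returned copy
    assert buf.is_pinned()

Accumulating into the pinned CPU tensor and writing to the memory-mapped
scores array only once at the end avoids one disk write per checkpoint.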
diff --git a/trak/utils.py b/trak/utils.py index e5a2905..d973eb1 100644 --- a/trak/utils.py +++ b/trak/utils.py @@ -1,6 +1,8 @@ from torch import Tensor import tempfile import torch +import numpy as np + ch = torch @@ -8,7 +10,9 @@ def test_install(use_fast_jl: bool = True): try: from trak import TRAKer except ImportError: - raise ImportError('TRAK is not installed! Please install it using `pip install traker`') + raise ImportError( + "TRAK is not installed! Please install it using `pip install traker`" + ) data = (ch.randn(20, 256), ch.randint(high=2, size=(20,))) model = ch.nn.Linear(256, 2, bias=False) @@ -17,29 +21,34 @@ def test_install(use_fast_jl: bool = True): with tempfile.TemporaryDirectory() as tmpdirname: data = [x.cuda() for x in data] model = model.cuda() - traker = TRAKer(model=model, - task='image_classification', - proj_dim=512, - save_dir=tmpdirname, - train_set_size=20, - logging_level=100) + traker = TRAKer( + model=model, + task="image_classification", + proj_dim=512, + save_dir=tmpdirname, + train_set_size=20, + logging_level=100, + ) traker.load_checkpoint(model.state_dict(), model_id=0) traker.featurize(data, num_samples=20) - print('TRAK and fast_jl are installed correctly!') + print("TRAK and fast_jl are installed correctly!") else: from trak.projectors import NoOpProjector + with tempfile.TemporaryDirectory() as tmpdirname: - traker = TRAKer(model=model, - task='image_classification', - train_set_size=20, - proj_dim=512, - save_dir=tmpdirname, - projector=NoOpProjector(), - device='cpu', - logging_level=100) + traker = TRAKer( + model=model, + task="image_classification", + train_set_size=20, + proj_dim=512, + save_dir=tmpdirname, + projector=NoOpProjector(), + device="cpu", + logging_level=100, + ) traker.load_checkpoint(model.state_dict(), model_id=0) traker.featurize(data, num_samples=20) - print('TRAK is installed correctly!') + print("TRAK is installed correctly!") def parameters_to_vector(parameters) -> Tensor: @@ -59,12 +68,16 @@ def get_num_params(model: torch.nn.Module) -> int: def is_not_buffer(ind, params_dict) -> bool: name = params_dict[ind] - if ('running_mean' in name) or ('running_var' in name) or ('num_batches_tracked' in name): + if ( + ("running_mean" in name) + or ("running_var" in name) + or ("num_batches_tracked" in name) + ): return False return True -def vectorize(g, arr) -> Tensor: +def vectorize(g, arr=None, device="cuda") -> Tensor: """ records result into arr @@ -73,6 +86,15 @@ def vectorize(g, arr) -> Tensor: :code:`grad_wi` has shape :code:`[batch_size, ...]` this function flattens :code:`g` to have shape :code:`[batch_size, num_params]`. 
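    For example (a minimal sketch with illustrative shapes):

    .. code-block:: python

        import torch as ch

        g = {
            "fc.weight": ch.randn(8, 4, 3),  # per-sample grads, batch of 8
            "fc.bias": ch.randn(8, 4),
        }
        flat = vectorize(g, device="cpu")
        assert flat.shape == (8, 4 * 3 + 4)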
""" + if arr is None: + g_elt = g[list(g.keys())[0]] + batch_size = g_elt.shape[0] + num_params = 0 + for param in g.values(): + assert param.shape[0] == batch_size + num_params += int(param.numel() / batch_size) + arr = ch.empty(size=(batch_size, num_params), dtype=g_elt.dtype, device=device) + pointer = 0 for param in g.values(): if len(param.shape) < 2: @@ -82,5 +104,228 @@ def vectorize(g, arr) -> Tensor: num_param = param[0].numel() p = param.flatten(start_dim=1).data - arr[:, pointer:pointer + num_param] = p + arr[:, pointer : pointer + num_param] = p.to(device) pointer += num_param + + return arr + + +def get_output_memory(features: Tensor, target_grads: Tensor, target_dtype: type): + output_shape = features.size(0) * target_grads.size(0) + output_dtype_size = ch.empty((1,), dtype=target_dtype).element_size() + + return output_shape * output_dtype_size + + +def get_free_memory(device): + reserved = ch.cuda.memory_reserved(device=device) + allocated = ch.cuda.memory_allocated(device=device) + + free = reserved - allocated + return free + + +def get_matrix_mult_standard( + features: Tensor, target_grads: Tensor, target_dtype: type +): + output = features @ target_grads.t() + return output.clone().to(target_dtype) + + +def get_matrix_mult_blockwise( + features: Tensor, target_grads: Tensor, target_dtype: type, bs: int +): + s_features = features.shape[0] + s_target_grads = target_grads.shape[0] + + bs = min(s_features, s_target_grads, bs) + + # Copy the data in a pinned memory location to allow non-blocking + # copies to the GPU + features = features.pin_memory() + target_grads = target_grads.pin_memory() + + # precompute all the blocks we will have to compute + slices = [] + for i in range(int(np.ceil(s_features / bs))): + for j in range(int(np.ceil(s_target_grads / bs))): + slices.append((slice(i * bs, (i + 1) * bs), slice(j * bs, (j + 1) * bs))) + + # Allocate memory for the final output. 
+ final_output = ch.empty( + (s_features, s_target_grads), dtype=target_dtype, device="cpu" + ) + + # Output buffers pinned on the CPU to be able to collect data from the + # GPU asynchronously + # For each of our (2) CUDA streams we need two output buffers: one + # is being written with the next batch of results while the + # other, already complete, is being copied onto the final output + + # If the size is not a multiple of batch size we need extra buffers + # with the proper shapes + outputs = [ + ch.zeros((bs, bs), dtype=target_dtype, device=features.device).pin_memory() + for _ in range(4) + ] + left_bottom = s_features % bs + options = [outputs] # List of buffers we can potentially use + if left_bottom: + outputs_bottom = [ + ch.zeros( + (left_bottom, bs), dtype=target_dtype, device=features.device + ).pin_memory() + for _ in range(4) + ] + options.append(outputs_bottom) + left_right = s_target_grads % bs + if left_right: + outputs_right = [ + ch.zeros( + (bs, left_right), dtype=target_dtype, device=features.device + ).pin_memory() + for _ in range(4) + ] + options.append(outputs_right) + if left_right and left_bottom: + outputs_corner = [ + ch.zeros( + (left_bottom, left_right), dtype=target_dtype, device=features.device + ).pin_memory() + for _ in range(4) + ] + options.append(outputs_corner) + + streams = [ch.cuda.Stream() for _ in range(2)] + + # The slice that was computed last and now needs to be copied onto the + # final output + previous_slice = None + + def find_buffer_for_shape(shape): + for buff in options: + if buff[0].shape == shape: + return buff + return None + + for i, (slice_i, slice_j) in enumerate(slices): + with ch.cuda.stream(streams[i % len(streams)]): + # Copy the relevant blocks from CPU to the GPU asynchronously + features_i = features[slice_i, :].cuda(non_blocking=True) + target_grads_j = target_grads[slice_j, :].cuda(non_blocking=True) + + output_slice = features_i @ target_grads_j.t() + + find_buffer_for_shape(output_slice.shape)[i % 4].copy_( + output_slice, non_blocking=False + ) + + # Write the previous batch of data from the temporary buffer + # onto the final one (note that this was done by the other stream, + # so we swap back to the other one) + with ch.cuda.stream(streams[(i + 1) % len(streams)]): + if previous_slice is not None: + output_slice = final_output[previous_slice[0], previous_slice[1]] + output_slice.copy_( + find_buffer_for_shape(output_slice.shape)[(i - 1) % 4], + non_blocking=True, + ) + + previous_slice = (slice_i, slice_j) + + # Wait for all the calculations/copies to be done + ch.cuda.synchronize() + + # Copy the last chunk to the final result (from the appropriate buffer) + output_slice = final_output[previous_slice[0], previous_slice[1]] + output_slice.copy_( + find_buffer_for_shape(output_slice.shape)[i % 4], non_blocking=True + ) + + return final_output + + +def get_matrix_mult( + features: Tensor, + target_grads: Tensor, + target_dtype: torch.dtype = None, + batch_size: int = 8096, + use_blockwise: bool = False, +) -> Tensor: + """Computes features @ target_grads.T. If the output matrix is too large to + fit in memory, it will be computed in blocks. + + Args: + features (Tensor): + The first matrix to multiply. + target_grads (Tensor): + The second matrix to multiply. + target_dtype (torch.dtype, optional): + The dtype of the output matrix. If None, defaults to the dtype of + features. Defaults to None. + batch_size (int, optional): + The batch size to use for blockwise matrix multiplication. 
Defaults + to 8096. + use_blockwise (bool, optional): + Whether or not to use blockwise matrix multiplication. Defaults to + False. + + """ + if target_dtype is None: + target_dtype = features.dtype + + if use_blockwise: + return get_matrix_mult_blockwise( + features.cpu(), target_grads.cpu(), target_dtype, batch_size + ) + elif features.device.type == "cpu": + return get_matrix_mult_standard(features, target_grads, target_dtype) + + output_memory = get_output_memory(features, target_grads, target_dtype) + free_memory = get_free_memory(features.device) + + if output_memory < free_memory: + return get_matrix_mult_standard(features, target_grads, target_dtype) + else: + return get_matrix_mult_blockwise( + features.cpu(), target_grads.cpu(), target_dtype, batch_size + ) + + +def get_parameter_chunk_sizes( + model: torch.nn.Module, + batch_size: int, +): + """The :class:`CudaProjector` supports projecting when the product of the + number of parameters and the batch size is less than the the max value of + int32. This function computes the number of parameters that can be projected + at once for a given model and batch size. + + The method returns a tuple containing the maximum number of parameters that + can be projected at once and a list of the actual number of parameters in + each chunk (a sequence of paramter groups). Used in + :class:`ChunkedCudaProjector`. + """ + param_shapes = [] + for p in model.parameters(): + param_shapes.append(p.numel()) + + param_shapes = np.array(param_shapes) + + chunk_sum = 0 + max_chunk_size = np.iinfo(np.uint32).max // batch_size + params_per_chunk = [] + + for ps in param_shapes: + if chunk_sum + ps >= max_chunk_size: + params_per_chunk.append(chunk_sum) + chunk_sum = 0 + + chunk_sum += ps + + if param_shapes.sum() - np.sum(params_per_chunk) > 0: + params_per_chunk.append(param_shapes.sum() - np.sum(params_per_chunk)) + + return max_chunk_size, params_per_chunk