diff --git a/examples/qnli.py b/examples/qnli.py index b8faa37..da9e48f 100644 --- a/examples/qnli.py +++ b/examples/qnli.py @@ -10,7 +10,6 @@ """ from argparse import ArgumentParser -import sys from tqdm import tqdm import torch as ch @@ -21,7 +20,6 @@ # Huggingface from datasets import load_dataset -import transformers from transformers import ( AutoConfig, AutoModelForSequenceClassification, @@ -30,7 +28,6 @@ ) - GLUE_TASK_TO_KEYS = { "cola": ("sentence", None), "mnli": ("premise", "hypothesis"), @@ -76,8 +73,8 @@ def __init__(self): def forward(self, input_ids, token_type_ids, attention_mask): return self.model(input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask).logits + token_type_ids=token_type_ids, + attention_mask=attention_mask).logits def get_dataset(split, inds=None): @@ -88,10 +85,9 @@ def get_dataset(split, inds=None): use_auth_token=None, ) label_list = raw_datasets["train"].features["label"].names - num_labels = len(label_list) sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS['qnli'] - label_to_id = None #{v: i for i, v in enumerate(label_list)} + label_to_id = None # {v: i for i, v in enumerate(label_list)} tokenizer = AutoTokenizer.from_pretrained( 'bert-base-cased', @@ -102,7 +98,7 @@ def get_dataset(split, inds=None): ) padding = "max_length" - max_seq_length=128 + max_seq_length = 128 def preprocess_function(examples): # Tokenize the texts @@ -113,7 +109,7 @@ def preprocess_function(examples): # Map labels to IDs (not necessary for GLUE tasks) if label_to_id is not None and "label" in examples: - result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] + result["label"] = [(label_to_id[lbl] if lbl != -1 else -1) for lbl in examples["label"]] return result raw_datasets = raw_datasets.map( @@ -178,7 +174,10 @@ def process_batch(batch): traker.finalize_features() - traker.start_scoring_checkpoint(exp_name='qnli', checkpoint=model.state_dict(), model_id=0, num_targets=VAL_SET_SIZE) + traker.start_scoring_checkpoint(exp_name='qnli', + checkpoint=model.state_dict(), + model_id=0, + num_targets=VAL_SET_SIZE) for batch in tqdm(loader_val, desc='Scoring..'): batch = process_batch(batch) batch = [x.cuda() for x in batch] diff --git a/tests/test_integration_cifar.py b/tests/test_integration_cifar.py index 8ba5b99..cad31c1 100644 --- a/tests/test_integration_cifar.py +++ b/tests/test_integration_cifar.py @@ -49,4 +49,3 @@ def test_cifar10(tmp_path, device='cpu'): @pytest.mark.cuda def test_cifar10_cuda(tmp_path): test_cifar10(tmp_path, device='cuda:0') - diff --git a/tests/test_integration_qnli.py b/tests/test_integration_qnli.py new file mode 100644 index 0000000..58364a3 --- /dev/null +++ b/tests/test_integration_qnli.py @@ -0,0 +1,165 @@ +from tqdm import tqdm +from torch.utils.data import DataLoader +import torch.nn as nn +import pytest +import logging + +from trak import TRAKer + +from datasets import load_dataset +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + default_data_collator, +) + + +GLUE_TASK_TO_KEYS = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +# for testing purposes +TRAIN_SET_SIZE = 20 +VAL_SET_SIZE = 10 + + +class SequenceClassificationModel(nn.Module): + """ + Wrapper for HuggingFace sequence classification models. + """ + def __init__(self): + super().__init__() + self.config = AutoConfig.from_pretrained( + 'bert-base-cased', + num_labels=2, + finetuning_task='qnli', + cache_dir=None, + revision='main', + use_auth_token=None, + ) + + self.model = AutoModelForSequenceClassification.from_pretrained( + 'bert-base-cased', + config=self.config, + cache_dir=None, + revision='main', + use_auth_token=None, + ignore_mismatched_sizes=False + ) + + self.model.eval().cuda() + + def forward(self, input_ids, token_type_ids, attention_mask): + return self.model(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask).logits + + +def get_dataset(split, inds=None): + raw_datasets = load_dataset( + "glue", + 'qnli', + cache_dir=None, + use_auth_token=None, + ) + sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS['qnli'] + + tokenizer = AutoTokenizer.from_pretrained( + 'bert-base-cased', + cache_dir=None, + use_fast=True, + revision='main', + use_auth_token=False + ) + + padding = "max_length" + max_seq_length = 128 + + def preprocess_function(examples): + # Tokenize the texts + args = ( + (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) + + return result + + raw_datasets = raw_datasets.map( + preprocess_function, + batched=True, + load_from_cache_file=(not False), + desc="Running tokenizer on dataset", + ) + + if split == 'train': + train_dataset = raw_datasets["train"] + ds = train_dataset + else: + eval_dataset = raw_datasets["validation"] + ds = eval_dataset + return ds + + +def init_loaders(batch_size=10): + ds_train = get_dataset('train') + ds_train = ds_train.select(range(TRAIN_SET_SIZE)) + ds_val = get_dataset('val') + ds_val = ds_val.select(range(VAL_SET_SIZE)) + return DataLoader(ds_train, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator), \ + DataLoader(ds_val, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator) + + +def process_batch(batch): + return batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['labels'] + + +# model too large to test on CPU +@pytest.mark.cuda +def test_qnli(tmp_path, device='cuda'): + loader_train, loader_val = init_loaders() + + # no need to load model from checkpoint, just testing featurization and scoring + model = SequenceClassificationModel() + + logger = logging.getLogger('QNLI') + logger.setLevel(logging.DEBUG) + logger.info(f'Initializing TRAKer with device {device}') + + traker = TRAKer(model=model, + task='text_classification', + train_set_size=TRAIN_SET_SIZE, + save_dir=tmp_path, + device=device, + logging_level=logging.DEBUG, + proj_dim=512) + + logger.info('Loading checkpoint') + traker.load_checkpoint(model.state_dict(), model_id=0) + logger.info('Loaded checkpoint') + for batch in tqdm(loader_train, desc='Featurizing..'): + # process batch into compatible form for TRAKer TextClassificationModelOutput + batch = process_batch(batch) + batch = [x.to(device) for x in batch] + traker.featurize(batch=batch, num_samples=batch[0].shape[0]) + + traker.finalize_features() + + traker.start_scoring_checkpoint(exp_name='qnli', + checkpoint=model.state_dict(), + model_id=0, + num_targets=VAL_SET_SIZE) + for batch in tqdm(loader_val, desc='Scoring..'): + batch = process_batch(batch) + batch = [x.to(device) for x in batch] + traker.score(batch=batch, num_samples=batch[0].shape[0]) + + traker.finalize_scores(exp_name='qnli') diff --git a/trak/projectors.py b/trak/projectors.py index fbeaff9..d736bfa 100644 --- a/trak/projectors.py +++ b/trak/projectors.py @@ -158,9 +158,15 @@ class BasicProjector(AbstractProjector): a CUDA-enabled device with compute capability >=7.0 (see https://developer.nvidia.com/cuda-gpus). """ - def __init__(self, grad_dim: int, proj_dim: int, seed: int, proj_type: - ProjectionType, device, block_size: int = 200, dtype=ch.float32, - model_id=0, *args, **kwargs) -> None: + def __init__(self, grad_dim: int, + proj_dim: int, + seed: int, + proj_type: ProjectionType, + device: torch.device, + block_size: int = 100, + dtype: torch.dtype = ch.float32, + model_id=0, + *args, **kwargs) -> None: super().__init__(grad_dim, proj_dim, seed, proj_type, device) self.block_size = min(self.proj_dim, block_size)