Skip to content

Commit

Permalink
qnli test; formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
kristian-georgiev committed Oct 25, 2023
1 parent 50d40ec commit ead038b
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 14 deletions.
19 changes: 9 additions & 10 deletions examples/qnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

from argparse import ArgumentParser
import sys
from tqdm import tqdm

import torch as ch
Expand All @@ -21,7 +20,6 @@

# Huggingface
from datasets import load_dataset
import transformers
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
Expand All @@ -30,7 +28,6 @@
)



GLUE_TASK_TO_KEYS = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
Expand Down Expand Up @@ -76,8 +73,8 @@ def __init__(self):

def forward(self, input_ids, token_type_ids, attention_mask):
return self.model(input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask).logits
token_type_ids=token_type_ids,
attention_mask=attention_mask).logits


def get_dataset(split, inds=None):
Expand All @@ -88,10 +85,9 @@ def get_dataset(split, inds=None):
use_auth_token=None,
)
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)
sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS['qnli']

label_to_id = None #{v: i for i, v in enumerate(label_list)}
label_to_id = None # {v: i for i, v in enumerate(label_list)}

tokenizer = AutoTokenizer.from_pretrained(
'bert-base-cased',
Expand All @@ -102,7 +98,7 @@ def get_dataset(split, inds=None):
)

padding = "max_length"
max_seq_length=128
max_seq_length = 128

def preprocess_function(examples):
# Tokenize the texts
Expand All @@ -113,7 +109,7 @@ def preprocess_function(examples):

# Map labels to IDs (not necessary for GLUE tasks)
if label_to_id is not None and "label" in examples:
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
result["label"] = [(label_to_id[lbl] if lbl != -1 else -1) for lbl in examples["label"]]
return result

raw_datasets = raw_datasets.map(
Expand Down Expand Up @@ -178,7 +174,10 @@ def process_batch(batch):

traker.finalize_features()

traker.start_scoring_checkpoint(exp_name='qnli', checkpoint=model.state_dict(), model_id=0, num_targets=VAL_SET_SIZE)
traker.start_scoring_checkpoint(exp_name='qnli',
checkpoint=model.state_dict(),
model_id=0,
num_targets=VAL_SET_SIZE)
for batch in tqdm(loader_val, desc='Scoring..'):
batch = process_batch(batch)
batch = [x.cuda() for x in batch]
Expand Down
1 change: 0 additions & 1 deletion tests/test_integration_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,3 @@ def test_cifar10(tmp_path, device='cpu'):
@pytest.mark.cuda
def test_cifar10_cuda(tmp_path):
    """Run `test_cifar10` again with `device='cuda:0'` (marked as a CUDA test)."""
    test_cifar10(tmp_path=tmp_path, device='cuda:0')

165 changes: 165 additions & 0 deletions tests/test_integration_qnli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
import pytest
import logging

from trak import TRAKer

from datasets import load_dataset
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
default_data_collator,
)


# Text column names for each GLUE task, as (first_text, second_text).
# The second entry is None for tasks that have a single text column.
GLUE_TASK_TO_KEYS = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

# for testing purposes: number of train / validation examples used
TRAIN_SET_SIZE, VAL_SET_SIZE = 20, 10


class SequenceClassificationModel(nn.Module):
    """
    Wrapper for HuggingFace sequence classification models.

    Builds a two-label `bert-base-cased` classifier configured for the
    `qnli` task and exposes a `forward` that returns only the logits.
    """
    def __init__(self):
        super().__init__()
        checkpoint_name = 'bert-base-cased'
        # Options passed identically to both the config and the model loader.
        hub_options = dict(cache_dir=None, revision='main', use_auth_token=None)

        self.config = AutoConfig.from_pretrained(
            checkpoint_name,
            num_labels=2,
            finetuning_task='qnli',
            **hub_options,
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            checkpoint_name,
            config=self.config,
            ignore_mismatched_sizes=False,
            **hub_options,
        )

        # Switch to eval mode and move the wrapped model to CUDA.
        self.model.eval().cuda()

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Call the wrapped model and return the `.logits` of its output."""
        outputs = self.model(input_ids=input_ids,
                             token_type_ids=token_type_ids,
                             attention_mask=attention_mask)
        return outputs.logits


def get_dataset(split, inds=None):
    """
    Load one split of GLUE/QNLI and tokenize it with `bert-base-cased`.

    Args:
        split: 'train' selects the training split; any other value selects
            the validation split.
        inds: optional iterable of row indices. When given, only those rows
            are kept. The subset is taken *before* tokenization, so only
            the requested rows are tokenized.

    Returns:
        The selected split (restricted to `inds` if provided) with the
        tokenizer outputs added as extra columns.
    """
    raw_datasets = load_dataset(
        "glue",
        'qnli',
        cache_dir=None,
        use_auth_token=None,
    )
    sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS['qnli']

    tokenizer = AutoTokenizer.from_pretrained(
        'bert-base-cased',
        cache_dir=None,
        use_fast=True,
        revision='main',
        use_auth_token=False
    )

    padding = "max_length"
    max_seq_length = 128

    def preprocess_function(examples):
        # Tokenize the texts (one or two text columns depending on the task)
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        return tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)

    # Narrow down to the rows that are actually needed before running the
    # tokenizer, instead of tokenizing every split and discarding most of it.
    ds = raw_datasets["train"] if split == 'train' else raw_datasets["validation"]
    if inds is not None:
        ds = ds.select(inds)

    return ds.map(
        preprocess_function,
        batched=True,
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )


def init_loaders(batch_size=10):
    """
    Build the train and validation loaders used by the QNLI test.

    Only the first TRAIN_SET_SIZE training rows and the first VAL_SET_SIZE
    validation rows are loaded and tokenized.

    Args:
        batch_size: batch size used for both loaders.

    Returns:
        A `(train_loader, val_loader)` tuple of unshuffled DataLoaders that
        collate with `default_data_collator`.
    """
    ds_train = get_dataset('train', inds=range(TRAIN_SET_SIZE))
    ds_val = get_dataset('val', inds=range(VAL_SET_SIZE))

    loader_train = DataLoader(ds_train, batch_size=batch_size, shuffle=False,
                              collate_fn=default_data_collator)
    loader_val = DataLoader(ds_val, batch_size=batch_size, shuffle=False,
                            collate_fn=default_data_collator)
    return loader_train, loader_val


def process_batch(batch):
    """
    Unpack a collated batch mapping into a positional tuple.

    Returns `(input_ids, token_type_ids, attention_mask, labels)`, each read
    from the same-named key of `batch`.
    """
    field_order = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
    return tuple(batch[name] for name in field_order)


# model too large to test on CPU
@pytest.mark.cuda
def test_qnli(tmp_path, device='cuda'):
    """
    Integration test: featurize and score a small QNLI subset with TRAKer.

    Args:
        tmp_path: pytest temporary directory, used as TRAKer's `save_dir`.
        device: device that the model and every batch are moved to.
    """
    loader_train, loader_val = init_loaders()

    # no need to load model from checkpoint, just testing featurization and scoring
    model = SequenceClassificationModel()
    # The wrapper places the model with a plain `.cuda()`; move it to the
    # requested device so it always matches the batches moved below.
    model.to(device)

    logger = logging.getLogger('QNLI')
    logger.setLevel(logging.DEBUG)
    logger.info(f'Initializing TRAKer with device {device}')

    traker = TRAKer(model=model,
                    task='text_classification',
                    train_set_size=TRAIN_SET_SIZE,
                    save_dir=tmp_path,
                    device=device,
                    logging_level=logging.DEBUG,
                    proj_dim=512)

    logger.info('Loading checkpoint')
    traker.load_checkpoint(model.state_dict(), model_id=0)
    logger.info('Loaded checkpoint')
    for batch in tqdm(loader_train, desc='Featurizing..'):
        # process batch into compatible form for TRAKer TextClassificationModelOutput
        batch = process_batch(batch)
        batch = [x.to(device) for x in batch]
        traker.featurize(batch=batch, num_samples=batch[0].shape[0])

    traker.finalize_features()

    traker.start_scoring_checkpoint(exp_name='qnli',
                                    checkpoint=model.state_dict(),
                                    model_id=0,
                                    num_targets=VAL_SET_SIZE)
    for batch in tqdm(loader_val, desc='Scoring..'):
        batch = process_batch(batch)
        batch = [x.to(device) for x in batch]
        traker.score(batch=batch, num_samples=batch[0].shape[0])

    traker.finalize_scores(exp_name='qnli')
12 changes: 9 additions & 3 deletions trak/projectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,15 @@ class BasicProjector(AbstractProjector):
a CUDA-enabled device with compute capability >=7.0 (see
https://developer.nvidia.com/cuda-gpus).
"""
def __init__(self, grad_dim: int, proj_dim: int, seed: int, proj_type:
ProjectionType, device, block_size: int = 200, dtype=ch.float32,
model_id=0, *args, **kwargs) -> None:
def __init__(self, grad_dim: int,
proj_dim: int,
seed: int,
proj_type: ProjectionType,
device: torch.device,
block_size: int = 100,
dtype: torch.dtype = ch.float32,
model_id=0,
*args, **kwargs) -> None:
super().__init__(grad_dim, proj_dim, seed, proj_type, device)

self.block_size = min(self.proj_dim, block_size)
Expand Down

0 comments on commit ead038b

Please sign in to comment.