Skip to content

Commit

Permalink
0.1.1 (#17)
Browse files Browse the repository at this point in the history
* minor fix

* make loading from save_dir optional

* sesh

* upload on dropbox

* some memory profiling tests

* bump version

* fix typo

* text classification model output

* qnli example

---------

Co-authored-by: sung-max <[email protected]>
  • Loading branch information
kristian-georgiev and sung-max authored Mar 23, 2023
1 parent 9dd570a commit 918d546
Show file tree
Hide file tree
Showing 13 changed files with 406 additions and 50 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,6 @@ dmypy.json
# results and logs
trak_results/
*.out

# session
Session.vim
187 changes: 187 additions & 0 deletions examples/qnli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""
Example applying TRAK to language models finetuned for text classification.
Dataset: GLUE QNLI
Model: bert-base-cased (https://huggingface.co/bert-base-cased)
Tokenizers and loaders are adapted from the Hugging Face example
(https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification).
"""

from argparse import ArgumentParser
import sys
from tqdm import tqdm

import torch as ch
import torch.nn as nn
from torch.utils.data import DataLoader

from trak import TRAKer

# Huggingface
from datasets import load_dataset
import transformers
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
default_data_collator,
)



# For each GLUE task, the dataset column(s) that hold the input text.
# The second entry is None for single-sentence tasks.
GLUE_TASK_TO_KEYS = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

# NOTE: CHANGE THIS IF YOU WANT TO RUN ON FULL DATASET
# Number of leading train / validation examples used by this example.
TRAIN_SET_SIZE = 5000
VAL_SET_SIZE = 100


class SequenceClassificationModel(nn.Module):
    """
    Wrapper for HuggingFace sequence classification models.

    Loads the pretrained ``bert-base-cased`` model configured with two labels
    for the QNLI finetuning task, switches it to eval mode and moves it to
    ``device``.

    Args:
        device: device the wrapped model is moved to. Defaults to ``'cuda'``,
            which matches the previously hard-coded ``.cuda()`` call.
    """
    def __init__(self, device='cuda'):
        super().__init__()
        self.config = AutoConfig.from_pretrained(
            'bert-base-cased',
            num_labels=2,
            finetuning_task='qnli',
            cache_dir=None,
            revision='main',
            use_auth_token=None,
        )

        self.model = AutoModelForSequenceClassification.from_pretrained(
            'bert-base-cased',
            config=self.config,
            cache_dir=None,
            revision='main',
            use_auth_token=None,
            ignore_mismatched_sizes=False
        )

        # The device is a parameter rather than a hard-coded .cuda() so the
        # wrapper can also be placed on a specific GPU or on the CPU.
        self.model.eval().to(device)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        Run the wrapped model and return only the ``logits`` attribute of
        its output.
        """
        return self.model(input_ids=input_ids,
                          token_type_ids=token_type_ids,
                          attention_mask=attention_mask).logits


def get_dataset(split, inds=None):
    """
    Load GLUE QNLI and tokenize the requested split with the
    ``bert-base-cased`` tokenizer.

    Args:
        split: ``'train'`` selects the training split; any other value
            (e.g. ``'val'``) selects the validation split.
        inds: optional indices of the examples to keep, forwarded to
            ``Dataset.select``. ``None`` (the default) keeps the whole split.

    Returns:
        The tokenized split, with inputs padded/truncated to 128 tokens.
    """
    raw_datasets = load_dataset(
        "glue",
        'qnli',
        cache_dir=None,
        use_auth_token=None,
    )
    sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS['qnli']

    tokenizer = AutoTokenizer.from_pretrained(
        'bert-base-cased',
        cache_dir=None,
        use_fast=True,
        revision='main',
        use_auth_token=False
    )

    padding = "max_length"
    max_seq_length = 128

    def preprocess_function(examples):
        # Tokenize the texts (one or two text columns depending on the task)
        if sentence2_key is None:
            texts = (examples[sentence1_key],)
        else:
            texts = (examples[sentence1_key], examples[sentence2_key])
        return tokenizer(*texts, padding=padding, max_length=max_seq_length, truncation=True)

    # Only the requested split is tokenized; the other splits are never
    # returned, so mapping them would be wasted work.
    split_name = "train" if split == 'train' else "validation"
    ds = raw_datasets[split_name].map(
        preprocess_function,
        batched=True,
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )

    # Honor the (previously ignored) `inds` argument.
    if inds is not None:
        ds = ds.select(inds)
    return ds


def init_model(ckpt_path, device='cuda'):
    """
    Build the QNLI classification model and load checkpoint weights into it.

    Args:
        ckpt_path: path to a checkpoint readable by ``torch.load`` that holds
            a state dict for the wrapped HuggingFace model.
        device: device the checkpoint tensors are mapped to and the returned
            model is moved to.

    Returns:
        The ``SequenceClassificationModel`` with the checkpoint loaded.
    """
    model = SequenceClassificationModel()
    # map_location keeps loading independent of the device the checkpoint
    # was saved from.
    sd = ch.load(ckpt_path, map_location=device)
    model.model.load_state_dict(sd)
    # Honor the (previously ignored) `device` argument.
    return model.to(device)


def init_loaders(batch_size=16):
    """
    Build unshuffled train and validation loaders over the first
    ``TRAIN_SET_SIZE`` / ``VAL_SET_SIZE`` examples of the tokenized splits.

    Args:
        batch_size: batch size used by both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    subsets = (
        get_dataset('train').select(range(TRAIN_SET_SIZE)),
        get_dataset('val').select(range(VAL_SET_SIZE)),
    )
    train_loader, val_loader = (
        DataLoader(subset,
                   batch_size=batch_size,
                   shuffle=False,
                   collate_fn=default_data_collator)
        for subset in subsets
    )
    return train_loader, val_loader


def process_batch(batch):
    """
    Unpack a collated batch mapping into the tuple
    ``(input_ids, token_type_ids, attention_mask, labels)``.
    """
    field_order = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
    return tuple(batch[field] for field in field_order)


if __name__ == "__main__":
    # Both command-line options are required string arguments.
    parser = ArgumentParser()
    for flag, description in (
        ('--ckpt', 'model checkpoint'),
        ('--out', 'dir to save TRAK scores and metadata to'),
    ):
        parser.add_argument(flag, type=str, help=description, required=True)
    args = parser.parse_args()

    device = 'cuda'
    loader_train, loader_val = init_loaders()
    model = init_model(args.ckpt, device)

    traker = TRAKer(
        model=model,
        task='text_classification',
        train_set_size=TRAIN_SET_SIZE,
        save_dir=args.out,
        device=device,
        proj_dim=1024,
    )

    # Featurize the training examples with the single checkpoint (model_id 0).
    traker.load_checkpoint(model.state_dict(), model_id=0)
    for raw_batch in tqdm(loader_train, desc='Featurizing..'):
        # process batch into compatible form for TRAKer TextClassificationModelOutput
        tensors = [field.cuda() for field in process_batch(raw_batch)]
        traker.featurize(batch=tensors, num_samples=tensors[0].shape[0])

    traker.finalize_features()

    # Score the validation examples against the same checkpoint.
    traker.start_scoring_checkpoint(model.state_dict(), model_id=0, num_targets=VAL_SET_SIZE)
    for raw_batch in tqdm(loader_val, desc='Scoring..'):
        tensors = [field.cuda() for field in process_batch(raw_batch)]
        traker.score(batch=tensors, num_samples=tensors[0].shape[0])

    scores = traker.finalize_scores()
5 changes: 5 additions & 0 deletions examples/slurm_example/featurize_and_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@ def main(model_id):
ckpt = model.state_dict()
# ==================================

# only the job with model_id == 0 is responsible for metadata.json;
# all other jobs do not attempt to read/write to metadata.json
should_load_from_save_dir = (model_id == 0)

traker = TRAKer(model=model,
task='image_classification',
save_dir='./slurm_example_results',
load_from_save_dir=should_load_from_save_dir,
train_set_size=len(ds_train),
device='cuda')

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from setuptools import setup

setup(name="traker",
version="0.1.0",
version="0.1.1",
description="TRAK: Attributing Model Behavior at Scale",
long_description="Check https://trak.csail.mit.edu/ to learn more about TRAK",
author="MadryLab",
Expand Down
66 changes: 66 additions & 0 deletions tests/memory_profiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from tqdm import tqdm
from pathlib import Path
from pytorch_memlab import LineProfiler, MemReporter
from trak import TRAKer
import torch
ch = torch

from utils import construct_rn9, get_dataloader, eval_correlations
from utils import download_cifar_checkpoints, download_cifar_betons


def test_cifar_acc(serialize=False, dtype=ch.float32, batch_size=100, tmp_path='/tmp/trak_results/'):
# Memory-profiling run of the CIFAR TRAK pipeline: featurize the train loader
# and score the val loader for every downloaded checkpoint, reporting memory
# usage through pytorch_memlab's MemReporter along the way.
# NOTE(review): `dtype` is accepted but never referenced below.
device = 'cuda:0'
model = construct_rn9().to(memory_format=ch.channels_last).to(device)
model = model.eval()

# Dataset files (betons) are fetched under tmp_path.
BETONS_PATH = Path(tmp_path).joinpath('cifar_betons')
BETONS = download_cifar_betons(BETONS_PATH)

loader_train = get_dataloader(BETONS, batch_size=batch_size, split='train')
loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val')

# Checkpoints are fetched under tmp_path and loaded onto the CPU.
CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts')
ckpt_files = download_cifar_checkpoints(CKPT_PATH)
ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files]

reporter = MemReporter()

traker = TRAKer(model=model,
task='image_classification',
proj_dim=1024,
train_set_size=10_000,
save_dir=tmp_path,
device=device)

# Featurizing pass over all checkpoints.
for model_id, ckpt in enumerate(ckpts):
# NOTE(review): `i` is assigned here but never read.
i = 0

traker.load_checkpoint(checkpoint=ckpt, model_id=model_id)

for batch in tqdm(loader_train, desc='Computing TRAK embeddings...'):
traker.featurize(batch=batch, num_samples=len(batch[0]))
# Emit a memory report from the MemReporter.
reporter.report()

traker.finalize_features()

# With serialize=True the TRAKer is discarded and re-created with the same
# arguments (including save_dir) before scoring.
# NOTE(review): presumably this exercises reloading state from save_dir — confirm.
if serialize:
del traker
traker = TRAKer(model=model,
task='image_classification',
proj_dim=1024,
train_set_size=10_000,
save_dir=tmp_path,
device=device)

# Scoring pass over all checkpoints.
for model_id, ckpt in enumerate(ckpts):
traker.start_scoring_checkpoint(ckpt, model_id, num_targets=2_000)
for batch in tqdm(loader_val, desc='Scoring...'):
traker.score(batch=batch, num_samples=len(batch[0]))

# The final scores are moved to the CPU; they are not used further here.
scores = traker.finalize_scores().cpu()

# Run test_cifar_acc under pytorch_memlab's LineProfiler, registering
# test_cifar_acc itself plus TRAKer.featurize and TRAKer.load_checkpoint,
# then print the collected statistics.
with LineProfiler(test_cifar_acc, TRAKer.featurize, TRAKer.load_checkpoint) as prof:
test_cifar_acc()

prof.print_stats()
15 changes: 9 additions & 6 deletions tests/test_cifar_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from trak.projectors import BasicProjector

from .utils import construct_rn9, get_dataloader, eval_correlations
from .utils import download_cifar_checkpoints, download_cifar_betons


def get_projector(use_cuda_projector, dtype):
Expand All @@ -33,12 +34,14 @@ def test_cifar_acc(serialize, use_cuda_projector, dtype, batch_size, tmp_path):
model = construct_rn9().to(memory_format=ch.channels_last).to(device)
model = model.eval()

loader_train = get_dataloader(batch_size=batch_size, split='train')
loader_val = get_dataloader(batch_size=batch_size, split='val')
BETONS_PATH = Path(tmp_path).joinpath('cifar_betons')
BETONS = download_cifar_betons(BETONS_PATH)

# TODO: put this on dropbox as well
CKPT_PATH = '/mnt/xfs/projects/trak/checkpoints/resnet9_cifar2/debug'
ckpt_files = list(Path(CKPT_PATH).rglob("*.pt"))
loader_train = get_dataloader(BETONS, batch_size=batch_size, split='train')
loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val')

CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts')
ckpt_files = download_cifar_checkpoints(CKPT_PATH)
ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files]

traker = TRAKer(model=model,
Expand Down Expand Up @@ -74,4 +77,4 @@ def test_cifar_acc(serialize, use_cuda_projector, dtype, batch_size, tmp_path):
scores = traker.finalize_scores().cpu()

avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path)
assert avg_corr > 0.058, 'correlation with 3 CIFAR-2 models should be >= 0.058'
assert avg_corr > 0.062, 'correlation with 3 CIFAR-2 models should be >= 0.062'
4 changes: 2 additions & 2 deletions tests/test_integration_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def test_mscoco(tmp_path, device='cuda:0'):

tokenizer = open_clip.get_tokenizer('RN50')

ds_train = datasets.CocoCaptions(root='/mnt/xfs/projects/trak/datasets/coco_csv/train2014',
annFile='/mnt/xfs/projects/trak/datasets/coco_csv/coco_train_karpathy.json'
ds_train = datasets.CocoCaptions(root='/path/to/coco_csv/train2014',
annFile='/path/to/coco_csv/coco_train_karpathy.json'
)

traker = TRAKer(model=model,
Expand Down
29 changes: 17 additions & 12 deletions tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from trak import TRAKer
from .utils import construct_rn9, get_dataloader, eval_correlations
from .utils import download_cifar_checkpoints, download_cifar_betons


@pytest.mark.cuda
Expand All @@ -15,12 +16,14 @@ def test_featurize_and_score_in_parallel(tmp_path):
model = construct_rn9().to(memory_format=ch.channels_last).to(device)
model = model.eval()

loader_train = get_dataloader(batch_size=batch_size, split='train')
loader_val = get_dataloader(batch_size=batch_size, split='val')
BETONS_PATH = Path(tmp_path).joinpath('cifar_betons')
BETONS = download_cifar_betons(BETONS_PATH)

# TODO: put this on dropbox as well
CKPT_PATH = '/mnt/xfs/projects/trak/checkpoints/resnet9_cifar2/debug'
ckpt_files = list(Path(CKPT_PATH).rglob("*.pt"))
loader_train = get_dataloader(BETONS, batch_size=batch_size, split='train')
loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val')

CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts')
ckpt_files = download_cifar_checkpoints(CKPT_PATH)
ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files]

# this should be essentially equivalent to running each
Expand Down Expand Up @@ -50,7 +53,7 @@ def test_featurize_and_score_in_parallel(tmp_path):
scores = traker.finalize_scores().cpu()

avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path)
assert avg_corr > 0.058, 'correlation with 3 CIFAR-2 models should be >= 0.058'
assert avg_corr > 0.062, 'correlation with 3 CIFAR-2 models should be >= 0.062'


@pytest.mark.cuda
Expand All @@ -61,12 +64,14 @@ def test_score_multiple(tmp_path):
model = construct_rn9().to(memory_format=ch.channels_last).to(device)
model = model.eval()

loader_train = get_dataloader(batch_size=batch_size, split='train')
loader_val = get_dataloader(batch_size=batch_size, split='val')
BETONS_PATH = Path(tmp_path).joinpath('cifar_betons')
BETONS = download_cifar_betons(BETONS_PATH)

loader_train = get_dataloader(BETONS, batch_size=batch_size, split='train')
loader_val = get_dataloader(BETONS, batch_size=batch_size, split='val')

# TODO: put this on dropbox as well
CKPT_PATH = '/mnt/xfs/projects/trak/checkpoints/resnet9_cifar2/debug'
ckpt_files = list(Path(CKPT_PATH).rglob("*.pt"))
CKPT_PATH = Path(tmp_path).joinpath('cifar_ckpts')
ckpt_files = download_cifar_checkpoints(CKPT_PATH)
ckpts = [ch.load(ckpt, map_location='cpu') for ckpt in ckpt_files]

traker = TRAKer(model=model,
Expand Down Expand Up @@ -97,4 +102,4 @@ def test_score_multiple(tmp_path):
scores = traker.finalize_scores().cpu()

avg_corr = eval_correlations(infls=scores, tmp_path=tmp_path)
assert avg_corr > 0.058, 'correlation with 3 CIFAR-2 models should be >= 0.058'
assert avg_corr > 0.062, 'correlation with 3 CIFAR-2 models should be >= 0.062'
Loading

0 comments on commit 918d546

Please sign in to comment.