
Commit

more multi-task options
root committed Dec 2, 2022
1 parent 7f2b08e commit 29e8bd4
Showing 4 changed files with 137 additions and 57 deletions.
35 changes: 19 additions & 16 deletions src/tasknet/models.py
@@ -5,7 +5,7 @@
 from torch.utils.data.dataloader import DataLoader
 from transformers.data.data_collator import InputDataClass
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler
+from torch.utils.data.sampler import RandomSampler, WeightedRandomSampler
 from typing import List, Union, Dict
 from transformers import (
     AutoModelForSeq2SeqLM,
@@ -22,6 +22,7 @@
 import functools
 from types import MappingProxyType
 from .tasks import Classification
+from .utils import to_dict
 from transformers import AutoTokenizer
 import magicattr

@@ -142,15 +143,18 @@ class MultitaskDataloader:
     data loaders.
     """

-    def __init__(self, dataloader_dict):
+    def __init__(self, dataloader_dict, p=1):
         self.dataloader_dict = dataloader_dict
+        N=max([len(x)**(1-p) for x in dataloader_dict.values()])
+        f_p = lambda x: int(N*x**p)
+
         self.num_batches_dict = {
-            task_name: len(dataloader)
+            task_name: f_p(len(dataloader))
            for task_name, dataloader in self.dataloader_dict.items()
         }
         self.task_name_list = list(self.dataloader_dict)
         self.dataset = [None] * sum(
-            len(dataloader.dataset) for dataloader in self.dataloader_dict.values()
+            f_p(len(dataloader.dataset)) for dataloader in self.dataloader_dict.values()
         )

     def __len__(self):
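The new `p` argument implements temperature-style task sampling: each task's batch count is rescaled as `f_p(n) = N * n**p` with `N = max(n**(1-p))` over the task sizes, so `p=1` keeps counts proportional to dataset size while `p=0` gives every task as many batches as the largest one. A standalone sketch of the rescaling (the task sizes here are invented):

```python
# Standalone sketch of the f_p rescaling above; task sizes are invented.
def batches_per_task(sizes, p=1):
    N = max(n ** (1 - p) for n in sizes.values())
    return {task: int(N * n ** p) for task, n in sizes.items()}

sizes = {"mnli": 392702, "rte": 2490}
print(batches_per_task(sizes, p=1))    # proportional: {'mnli': 392702, 'rte': 2490}
print(batches_per_task(sizes, p=0.5))  # flattened: rte oversampled ~12x
print(batches_per_task(sizes, p=0))    # uniform: both tasks get 392702
```

Whatever the value of `p`, the largest task keeps its original count (`f_p(n_max) = n_max`); lowering `p` only up-samples the smaller tasks.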
@@ -160,9 +164,6 @@ def __iter__(self):
         """
         For each batch, sample a task, and yield a batch from the respective
         task Dataloader.
-        We use size-proportional sampling, but you could easily modify this
-        to sample from some other distribution.
         """
         task_choice_list = []
         for i, task_name in enumerate(self.task_name_list):
@@ -190,13 +191,9 @@ class default:
             save_steps = 1000000
             label_names = ["labels"]
             include_inputs_for_metrics = True

-        default = {
-            k: v for (k, v) in default.__dict__.items() if not k.startswith("__")
-        }
-        hparams = {
-            k: v for (k, v) in hparams.__dict__.items() if not k.startswith("__")
-        }
+        default, hparams = to_dict(default), to_dict(hparams)
+        self.p = hparams.get('p', 1)

         trainer_args = transformers.TrainingArguments(
             **{**default, **fc.project(hparams, dir(transformers.TrainingArguments))},
@@ -278,6 +275,12 @@ def evaluate(self, **kwargs):
             outputs += [output]
         return fc.join(outputs)

+    def task_batch_size(self, task_name):
+        if task_name.task_type=='MultipleChoice':
+            return max(1, self.args.train_batch_size//4)
+        else:
+            return self.args.train_batch_size
+
     def get_single_train_dataloader(self, task_name, train_dataset):
         """
         Create a single-task data loader that also yields task names
@@ -294,7 +297,7 @@ def get_single_train_dataloader(self, task_name, train_dataset):
             task_name=task_name,
             data_loader=DataLoader(
                 train_dataset,
-                batch_size=self.args.train_batch_size,
+                batch_size=self.task_batch_size(task_name),
                 sampler=train_sampler,
                 collate_fn=self.data_collator.__call__,
             ),
@@ -312,7 +315,7 @@ def get_train_dataloader(self):
             {
                 task_name: self.get_single_train_dataloader(task_name, task_dataset)
                 for task_name, task_dataset in self.train_dataset.items()
-            }
+            }, p=self.p,
         )

     def get_eval_dataloader(self, eval_dataset=None):
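The `task_batch_size` hook compensates for the multiple-choice collator: each multiple-choice example is flattened into one encoded sequence per candidate answer, so the model actually sees `batch_size × num_choices` sequences per step. Quartering the batch keeps the memory footprint close to that of single-sequence tasks, as in this rough sketch (the numbers are hypothetical):

```python
# Hypothetical numbers: sequences per forward pass with and without the hook.
train_batch_size, num_choices = 32, 4

classification_seqs = train_batch_size                              # 32 sequences
multiple_choice_seqs = max(1, train_batch_size // 4) * num_choices  # 8 * 4 = 32 sequences
```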
87 changes: 60 additions & 27 deletions src/tasknet/taskparser.py
@@ -1,4 +1,8 @@
 from dataclasses import dataclass
+import numpy as np
+import funcy as fc
+from datasets import Dataset, DatasetDict
+import pandas as pd

 split_mapping={
     'train':['train_split','training'],
@@ -43,19 +47,37 @@ def fix_splits(dataset):
     return dataset

 fields_mapping={
-    'sentence1':['premise','sentence','sentence1','text','head','question1','question'],
-    'sentence2':['hypothesis','sentence2','tail','question2'],
+    'sentence1':['premise','sentence','sentence1','text','head','question1','question','sentence_A'],
+    'sentence2':['hypothesis','sentence2','tail','question2','sentence_B'],
     'labels':['label','labels','relation','gold_label']

 }

-def align_fields(x):
+def align_fields(dataset):
     for k,v in fields_mapping.items():
-        bad_fields = [field for field in v if field in x and field!=k]
+        bad_fields = [field for field in v if field in dataset['train'].features and field!=k]
         if bad_fields:
-            x[k]=x[bad_fields[0]]
-            del x[bad_fields[0]]
-    return x
+            dataset=dataset.rename_column(bad_fields[0], k)
+    return dataset
+
+def align_fields_MultipleChoice(dataset):
+    fields_mapping={'inputs':['sentence1','question']}
+    for k,v in fields_mapping.items():
+        bad_fields = [field for field in v if field in dataset['train'].features and field!=k]
+        if bad_fields:
+            dataset=dataset.rename_column(bad_fields[0], k)
+    return dataset
+
+def process_labels(dataset):
+    if dataset['train'].features['labels'].dtype!='string':
+        return dataset
+
+    labels=pd.Series(dataset['train']['labels']).value_counts().reset_index()
+    label_to_index=fc.flip(labels['index'].to_dict())
+    def tokenize_labels(x):
+        x['labels']=label_to_index.get(x['labels'],max(label_to_index.values())+1)
+        return x
+    dataset=dataset.map(tokenize_labels)
+    return dataset

 def get_name(dataset):
     return str(dataset.cache_files).split('/.cache/huggingface/datasets/')[-1].split('/')[0]
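The new `process_labels` covers datasets whose `labels` column contains raw strings: labels are ranked by frequency and replaced with their integer rank, and labels unseen at training time fall back to one index past the known ones. A minimal sketch of the encoding step on invented data:

```python
# Minimal sketch of the label-encoding idiom above; the data is invented.
import funcy as fc
import pandas as pd

labels = pd.Series(["entailment"] * 3 + ["neutral"] * 2 + ["contradiction"])
counts = labels.value_counts().reset_index()         # most frequent label first
label_to_index = fc.flip(counts["index"].to_dict())  # {label: frequency rank}
print(label_to_index)  # {'entailment': 0, 'neutral': 1, 'contradiction': 2}
```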
@@ -72,15 +94,18 @@ def task_type(x):
         return 'MultipleChoice'
     if x.dataset_name in {'bigbench','blimp','hendrycks_test'}:
         return 'MultipleChoice'
-    if x.dataset_name in {'glue','anli','tweet_eval','pragmeval','relbert/lexical_relation_classification','metaeval/linguisticprobing'}:
+    if x.dataset_name in {'glue','anli','tweet_eval','pragmeval',
+        'relbert/lexical_relation_classification','metaeval/linguisticprobing',
+        'paws','lex_glue','sick','snips_built_in_intents','discovery','ethos','imppres'}:
         return 'Classification'
     if x.dataset_name in {'conll2003'}:
         return 'TokenClassification'

 @dataclass
 class TaskParser:
-    #todo sick
-    def normalize_anli(dataset):
+    max_choices:int=None
+    #todo: sick
+    def normalize_anli(self, dataset):
         l=[]
         for i in '123':
             split=[f'train_r{i}',f'dev_r{i}',f'test_r{i}']
@@ -90,41 +115,41 @@ def normalize_anli(dataset):
             l+=[align_splits(ds)]
         return l

-    def normalize_conll2003(dataset):
+    def normalize_conll2003(self, dataset):
         l=[]
         for y in ['pos_tags', 'chunk_tags', 'ner_tags']:
             ds=dataset.rename_column('pos_tags','labels')
             setattr(ds,'task_config',f'label:{y}')
             l+=[ds]
         return l

-    def normalize_blimp(dataset):
+    def normalize_blimp(self, dataset):
         def add_label(x):
             x['label']=0
-            x['s1']=''
+            x['inputs']=''
             return x
         dataset=dataset.map(add_label).\
             rename_column('sentence_good','choice0').\
             rename_column('sentence_bad','choice1')
         return dataset

-    def normalize_hendrycks_test(dataset):
+    def normalize_hendrycks_test(self, dataset):
         def reformat(x):
             for i in range(4):
                 x[f'choice{i}']=x['choices'][i]
             del x['choices']
             return x
         return dataset.map(reformat).rename_column('answer','labels')

-    def normalize_bigbench(dataset):
+    def normalize_bigbench(self, dataset):

         try:
             minimum_answer_counts=min(
                 [ds.with_format("pandas")["multiple_choice_targets"].map(len).min()
                 for ds in dataset.values()
                 ]
             )
-            assert minimum_answer_counts
+            assert minimum_answer_counts<9
             print('minimum_answer_counts:',minimum_answer_counts)
         except:
             raise ValueError('Unsupported bigbench format')
@@ -140,23 +165,31 @@ def cap_options(x,n=None):
             return x

         def reformat(x):
-            x=cap_options(x,minimum_answer_counts-1)
+            n_options= self.max_choices if self.max_choices else (0,minimum_answer_counts-1)
+            x=cap_options(x,n_options)
             x['labels']=np.argmax(x['multiple_choice_scores'])
             for i,o in enumerate(x['multiple_choice_targets']):
                 x[f'choice{i}']=o
             return x
         dataset= dataset.map(reformat)
+        dataset=dataset_deduplicate(dataset,subset=['inputs','choice0'])
         return dataset

-    def parse(dataset,dataset_name=None):
+    def parse(self, dataset,dataset_name=None, task_type=None):
         if not dataset_name:
             dataset_name=get_name(dataset)
         print('name:',dataset_name)
-        if hasattr(TaskParser, f'normalize_{dataset_name}'):
-            dataset=getattr(TaskParser, f'normalize_{dataset_name}')(dataset)
-        if type(dataset)==list:
-            dataset=dataset[0]
-        dataset=align_splits(dataset)
-        dataset=fix_splits(dataset)
-        dataset=dataset.map(align_fields)
-        return dataset
+        if hasattr(self, f'normalize_{dataset_name}'):
+            dataset=getattr(self, f'normalize_{dataset_name}')(dataset)
+        if type(dataset)!=list:
+            datasets=[dataset]
+        else:
+            datasets=dataset
+        l=[]
+        for dataset in datasets:
+            dataset=align_splits(dataset)
+            dataset=fix_splits(dataset)
+            dataset=align_fields(dataset)
+            dataset=process_labels(dataset)
+            if task_type=='MultipleChoice':
+                dataset=align_fields_MultipleChoice(dataset)
+            l+=[dataset]
+        return l
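`parse` is now an instance method and always returns a list, since normalizers such as `normalize_anli` or `normalize_conll2003` can split one dataset into several task variants. A hedged usage sketch (the dataset choice is illustrative and is downloaded on first run):

```python
# Illustrative usage of the reworked parse(); downloads GLUE/RTE on first run.
from datasets import load_dataset

parser = TaskParser()
parsed = parser.parse(load_dataset("glue", "rte"),
                      dataset_name="glue", task_type="Classification")
for ds in parsed:  # one normalized DatasetDict per produced task
    print(ds["train"].features)
```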
56 changes: 43 additions & 13 deletions src/tasknet/tasks.py
@@ -15,24 +15,31 @@
 from dataclasses import dataclass
 import re
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from frozendict import frozendict

+import inspect
 load_dataset = lazy_func(datasets.load_dataset)

-def get_name(dataset):
+def get_dataset_name(dataset):
     try:
-        s = str(dataset.cache_files.values())
-        return re.search(r"/datasets/(.*?)/default/", s).group(1).split("___")[-1]
+        s="/".join(dataset.cache_files['train'][0]['filename'].split('/huggingface/datasets/')[-1].split('/')[:-3])
+        return s
     except:
         return ""

+def sample_dataset(dataset,n=10000, n_eval=1000):
+    for k in dataset:
+        n_k=(n if k=='train' else n_eval)
+        if n_k and len(dataset[k])>n_k:
+            dataset[k]=dataset[k].select(range(n_k))
+    return dataset

 @dataclass
 class Task:
     dataset: Dataset = None
     name: str = ""
     tokenizer: PreTrainedTokenizerBase = None
-    tokenizer_kwargs: ... = fdict(padding="max_length", max_length=256)
+    tokenizer_kwargs: ... = fdict(padding="max_length", max_length=256,truncation=True)
+    max_rows:int=None
+    max_rows_eval:int=None

     def __hash__(self):
         return hash(str(self.dataset.__dict__))
Expand All @@ -47,10 +54,12 @@ def __post_init__(self):
name = "/".join(self.dataset)
self.dataset = load_dataset(*self.dataset)
else:
name = get_name(self.dataset)
name = get_dataset_name(self.dataset)

if not self.name:
self.name = name
self.dataset=sample_dataset(self.dataset,self.max_rows,self.max_rows_eval)


def set_tokenizer(self, tokenizer):
self.tokenizer = tokenizer
@@ -70,7 +79,12 @@ def __post_init__(self):
         super().__post_init__()
         if not self.num_labels:
             target = self.dataset["train"].features[self.y]
-            self.num_labels = 1 if "float" in target.dtype else target.num_classes
+            if "float" in target.dtype:
+                self.num_labels = 1
+            elif hasattr(target,'num_classes'):
+                self.num_labels=target.num_classes
+            else:
+                self.num_labels=len(set(self.dataset['train'][self.y]))

     def preprocess_function(self, examples):
         inputs = (
@@ -110,7 +124,8 @@ def __call__(self, features):
         ]
         flattened_features = sum(flattened_features, [])

-        batch = self.tokenizer.pad(flattened_features, **self.tokenizer_kwargs)
+        pad_args=inspect.signature(self.tokenizer.pad).parameters.keys()
+        batch = self.tokenizer.pad(flattened_features, **fc.project(self.tokenizer_kwargs,pad_args))

         # Un-flatten
         batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
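Since `tokenizer_kwargs` now carries `truncation=True`, which `tokenizer.pad` does not accept, the collator filters the kwargs against `pad`'s signature before the call. The same idiom in isolation (`call_with_accepted_kwargs` is an illustrative helper, not part of tasknet):

```python
# Signature-based kwarg filtering in isolation; the helper name is invented.
import inspect
import funcy as fc

def call_with_accepted_kwargs(fn, *args, **kwargs):
    accepted = inspect.signature(fn).parameters.keys()
    return fn(*args, **fc.project(kwargs, accepted))
```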
@@ -131,7 +146,9 @@ class MultipleChoice(Classification):
     def __post_init__(self):
         super().__post_init__()
         self.data_collator.tokenizer_kwargs = self.tokenizer_kwargs
-
+        choices = [x for x in self.dataset['train'].features if re.match('choice\d+',x)]
+        if choices and not self.choices:
+            self.choices=choices
     def set_tokenizer(self, tokenizer):
         self.tokenizer = self.data_collator.tokenizer= tokenizer
@@ -149,7 +166,7 @@ def preprocess_function(self, examples):

         # Tokenize
         tokenized_examples = self.tokenizer(
-            first_sentences, second_sentences, truncation=True
+            first_sentences, second_sentences, **self.tokenizer_kwargs
         )

         # Un-flatten
@@ -279,8 +296,21 @@ def preprocess_function(self, batch):
     def _explode(result,prefix=''):
         return {f'{prefix}{k}_{a}_{b}'.replace("_mid","").replace("_fmeasure",""):round(getattr(getattr(v,b),a)*100,3)\
             for (k,v) in result.items() for a in ['precision','recall','fmeasure'] for b in ['low','mid','high']}

+    @staticmethod
+    def _postprocess_text(preds, labels):
+        import nltk
+        preds = [pred.strip() for pred in preds]
+        labels = [label.strip() for label in labels]
+
+        # rougeLSum expects newline after each sentence
+        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+        return preds, labels
+
     def compute_metrics(self, eval_preds):
         preds, labels = eval_preds
         if isinstance(preds, tuple):
             preds = preds[0]
         decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
@@ -289,7 +319,7 @@ def compute_metrics(self, eval_preds):
         decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

         # Some simple post-processing
-        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+        decoded_preds, decoded_labels = self._postprocess_text(decoded_preds, decoded_labels)
         g = decoded_preds, decoded_labels
         result = self.metric.compute(
             predictions=decoded_preds, references=decoded_labels, use_stemmer=True
