diff --git a/src/tasknet/models.py b/src/tasknet/models.py
index b75f9d5..a33937d 100755
--- a/src/tasknet/models.py
+++ b/src/tasknet/models.py
@@ -5,7 +5,7 @@
 from torch.utils.data.dataloader import DataLoader
 from transformers.data.data_collator import InputDataClass
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler
+from torch.utils.data.sampler import RandomSampler, WeightedRandomSampler
 from typing import List, Union, Dict
 from transformers import (
     AutoModelForSeq2SeqLM,
@@ -22,6 +22,7 @@
 import functools
 from types import MappingProxyType

 from .tasks import Classification
+from .utils import to_dict
 from transformers import AutoTokenizer
 import magicattr

@@ -142,15 +143,18 @@ class MultitaskDataloader:
     data loaders.
     """

-    def __init__(self, dataloader_dict):
+    def __init__(self, dataloader_dict, p=1):
         self.dataloader_dict = dataloader_dict
+        # p acts as a sampling temperature: p=1 keeps size-proportional sampling, p=0 samples every task as often as the largest one.
+        N=max([len(x)**(1-p) for x in dataloader_dict.values()])
+        f_p = lambda x: int(N*x**p)
+
         self.num_batches_dict = {
-            task_name: len(dataloader)
+            task_name: f_p(len(dataloader))
             for task_name, dataloader in self.dataloader_dict.items()
         }
         self.task_name_list = list(self.dataloader_dict)
         self.dataset = [None] * sum(
-            len(dataloader.dataset) for dataloader in self.dataloader_dict.values()
+            f_p(len(dataloader.dataset)) for dataloader in self.dataloader_dict.values()
         )

     def __len__(self):
@@ -160,9 +164,6 @@ def __iter__(self):
         """
        For each batch, sample a task, and yield a batch from the respective
        task Dataloader.
-
-        We use size-proportional sampling, but you could easily modify this
-        to sample from some-other distribution.
         """
         task_choice_list = []
         for i, task_name in enumerate(self.task_name_list):
@@ -190,13 +191,9 @@ class default:
             save_steps = 1000000
             label_names = ["labels"]
             include_inputs_for_metrics = True
-
-        default = {
-            k: v for (k, v) in default.__dict__.items() if not k.startswith("__")
-        }
-        hparams = {
-            k: v for (k, v) in hparams.__dict__.items() if not k.startswith("__")
-        }
+
+        default, hparams = to_dict(default), to_dict(hparams)
+        self.p = hparams.get('p', 1)

         trainer_args = transformers.TrainingArguments(
             **{**default, **fc.project(hparams, dir(transformers.TrainingArguments))},
@@ -278,6 +275,12 @@ def evaluate(self, **kwargs):
             outputs += [output]
         return fc.join(outputs)

+    def task_batch_size(self,task_name):
+        # Multiple-choice inputs are unrolled into one row per choice, so use a smaller batch size (but at least 1).
+        if task_name.task_type=='MultipleChoice':
+            return max(1, self.args.train_batch_size//4)
+        else:
+            return self.args.train_batch_size
+
     def get_single_train_dataloader(self, task_name, train_dataset):
         """
         Create a single-task data loader that also yields task names
@@ -294,7 +297,7 @@ def get_single_train_dataloader(self, task_name, train_dataset):
             task_name=task_name,
             data_loader=DataLoader(
                 train_dataset,
-                batch_size=self.args.train_batch_size,
+                batch_size=self.task_batch_size(task_name),
                 sampler=train_sampler,
                 collate_fn=self.data_collator.__call__,
             ),
@@ -312,7 +315,7 @@ def get_train_dataloader(self):
             {
                 task_name: self.get_single_train_dataloader(task_name, task_dataset)
                 for task_name, task_dataset in self.train_dataset.items()
-            }
+            }, p=self.p,
         )

     def get_eval_dataloader(self, eval_dataset=None):
diff --git a/src/tasknet/taskparser.py b/src/tasknet/taskparser.py
index 6d8fbe2..5a57dfa 100644
--- a/src/tasknet/taskparser.py
+++ b/src/tasknet/taskparser.py
@@ -1,4 +1,8 @@
 from dataclasses import dataclass
+import numpy as np
+import funcy as fc
+from datasets import Dataset, DatasetDict
+import pandas as pd

 split_mapping={
     'train':['train_split','training'],
@@ -43,19 +47,37 @@ def fix_splits(dataset):
     return dataset

 fields_mapping={
-    'sentence1':['premise','sentence','sentence1','text','head','question1','question'],
-    'sentence2':['hypothesis','sentence2','tail','question2'],
+    'sentence1':['premise','sentence','sentence1','text','head','question1','question','sentence_A'],
+    'sentence2':['hypothesis','sentence2','tail','question2','sentence_B'],
     'labels':['label','labels','relation','gold_label']
-    }
-def align_fields(x):
+    }
+def align_fields(dataset):
+    for k,v in fields_mapping.items():
+        bad_fields = [field for field in v if field in dataset['train'].features and field!=k]
+        if bad_fields:
+            dataset=dataset.rename_column(bad_fields[0], k)
+    return dataset
+
+def align_fields_MultipleChoice(dataset):
+    fields_mapping={'inputs':['sentence1','question']}
     for k,v in fields_mapping.items():
-        bad_fields = [field for field in v if field in x and field!=k]
+        bad_fields = [field for field in v if field in dataset['train'].features and field!=k]
         if bad_fields:
-            x[k]=x[bad_fields[0]]
-            del x[bad_fields[0]]
-    return x
+            dataset=dataset.rename_column(bad_fields[0], k)
+    return dataset
+
+def process_labels(dataset):
+    if dataset['train'].features['labels'].dtype!='string':
+        return dataset
+
+    labels=pd.Series(dataset['train']['labels']).value_counts().reset_index()
+    label_to_index=fc.flip(labels['index'].to_dict())
+    def tokenize_labels(x):
+        x['labels']=label_to_index.get(x['labels'],max(label_to_index.values())+1)
+        return x
+    dataset=dataset.map(tokenize_labels)
+    return dataset

 def get_name(dataset):
     return str(dataset.cache_files).split('/.cache/huggingface/datasets/')[-1].split('/')[0]
@@ -72,15 +94,18 @@ def task_type(x):
         return 'MultipleChoice'
     if x.dataset_name in {'bigbench','blimp','hendrycks_test'}:
         return 'MultipleChoice'
-    if x.dataset_name in {'glue','anli','tweet_eval','pragmeval','relbert/lexical_relation_classification','metaeval/linguisticprobing'}:
+    if x.dataset_name in {'glue','anli','tweet_eval','pragmeval',
+        'relbert/lexical_relation_classification','metaeval/linguisticprobing',
+        'paws','lex_glue','sick','snips_built_in_intents','discovery','ethos','imppres'}:
         return 'Classification'
     if x.dataset_name in {'conll2003'}:
         return 'TokenClassification'

 @dataclass
 class TaskParser:
-    #todo sick
-    def normalize_anli(dataset):
+    max_choices:int=None
+    #todo: sick
+    def normalize_anli(self, dataset):
         l=[]
         for i in '123':
             split=[f'train_r{i}',f'dev_r{i}',f'test_r{i}']
@@ -90,7 +115,7 @@ def normalize_anli(dataset):
             l+=[align_splits(ds)]
         return l

-    def normalize_conll2003(dataset):
+    def normalize_conll2003(self, dataset):
         l=[]
         for y in ['pos_tags', 'chunk_tags', 'ner_tags']:
             ds=dataset.rename_column('pos_tags','labels')
@@ -98,17 +123,17 @@
             l+=[ds]
         return l

-    def normalize_blimp(dataset):
+    def normalize_blimp(self, dataset):
         def add_label(x):
             x['label']=0
-            x['s1']=''
+            x['inputs']=''
             return x
         dataset=dataset.map(add_label).\
             rename_column('sentence_good','choice0').\
             rename_column('sentence_bad','choice1')
         return dataset

-    def normalize_hendrycks_test(dataset):
+    def normalize_hendrycks_test(self, dataset):
         def reformat(x):
             for i in range(4):
                 x[f'choice{i}']=x['choices'][i]
@@ -116,7 +141,7 @@ def reformat(x):
             return x
         return dataset.map(reformat).rename_column('answer','labels')

-    def normalize_bigbench(dataset):
+    def normalize_bigbench(self, dataset):
         try:
             minimum_answer_counts=min(
@@ -124,7 +149,7 @@
                 for ds in dataset.values()
                 ]
             )
-            assert minimum_answer_counts
+            assert minimum_answer_counts<9
             print('minimum_answer_counts:',minimum_answer_counts)
         except:
             raise ValueError('Unsupported bigbench format')
@@ -140,7 +165,8 @@ def cap_options(x,n=None):
             return x

         def reformat(x):
-            x=cap_options(x,minimum_answer_counts-1)
+            n_options= self.max_choices if self.max_choices else minimum_answer_counts-1
+            x=cap_options(x,n_options)
             x['labels']=np.argmax(x['multiple_choice_scores'])
             for i,o in enumerate(x['multiple_choice_targets']):
                 x[f'choice{i}']=o
@@ -148,15 +174,22 @@ def reformat(x):
         dataset= dataset.map(reformat)
         dataset=dataset_deduplicate(dataset,subset=['inputs','choice0'])
         return dataset
-    def parse(dataset,dataset_name=None):
+
+    def parse(self, dataset,dataset_name=None, task_type=None):
         if not dataset_name:
             dataset_name=get_name(dataset)
             print('name:',dataset_name)
-        if hasattr(TaskParser, f'normalize_{dataset_name}'):
-            dataset=getattr(TaskParser, f'normalize_{dataset_name}')(dataset)
-        if type(dataset)==list:
-            dataset=dataset[0]
-        dataset=align_splits(dataset)
-        dataset=fix_splits(dataset)
-        dataset=dataset.map(align_fields)
-        return dataset
\ No newline at end of file
+        if hasattr(self, f'normalize_{dataset_name}'):
+            dataset=getattr(self, f'normalize_{dataset_name}')(dataset)
+        datasets = dataset if type(dataset)==list else [dataset]
+        l=[]
+        for dataset in datasets:
+            dataset=align_splits(dataset)
+            dataset=fix_splits(dataset)
+            dataset=align_fields(dataset)
+            dataset=process_labels(dataset)
+            if task_type=='MultipleChoice':
+                dataset=align_fields_MultipleChoice(dataset)
+            l+=[dataset]
+        return l
\ No newline at end of file
diff --git a/src/tasknet/tasks.py b/src/tasknet/tasks.py
index 24536d7..2f1890d 100755
--- a/src/tasknet/tasks.py
+++ b/src/tasknet/tasks.py
@@ -15,24 +15,31 @@ from dataclasses import dataclass
 import re
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from frozendict import frozendict
-
+import inspect
 load_dataset = lazy_func(datasets.load_dataset)

-def get_name(dataset):
+def get_dataset_name(dataset):
     try:
-        s = str(dataset.cache_files.values())
-        return re.search(r"/datasets/(.*?)/default/", s).group(1).split("___")[-1]
+        s="/".join(dataset.cache_files['train'][0]['filename'].split('/huggingface/datasets/')[-1].split('/')[:-3])
+        return s
     except:
         return ""

+def sample_dataset(dataset,n=10000, n_eval=1000):
+    for k in dataset:
+        n_k=(n if k=='train' else n_eval)
+        if n_k and len(dataset[k])>n_k:
+            dataset[k]=dataset[k].select(range(n_k))
+    return dataset

 @dataclass
 class Task:
     dataset: Dataset = None
     name: str = ""
     tokenizer: PreTrainedTokenizerBase = None
-    tokenizer_kwargs: ... = fdict(padding="max_length", max_length=256)
+    tokenizer_kwargs: ... = fdict(padding="max_length", max_length=256,truncation=True)
+    max_rows:int=None
+    max_rows_eval:int=None

     def __hash__(self):
         return hash(str(self.dataset.__dict__))
@@ -47,10 +54,12 @@ def __post_init__(self):
             name = "/".join(self.dataset)
             self.dataset = load_dataset(*self.dataset)
         else:
-            name = get_name(self.dataset)
+            name = get_dataset_name(self.dataset)
         if not self.name:
             self.name = name

+        self.dataset=sample_dataset(self.dataset,self.max_rows,self.max_rows_eval)
+
     def set_tokenizer(self, tokenizer):
         self.tokenizer = tokenizer
@@ -70,7 +79,12 @@ def __post_init__(self):
         super().__post_init__()
         if not self.num_labels:
             target = self.dataset["train"].features[self.y]
-            self.num_labels = 1 if "float" in target.dtype else target.num_classes
+            if "float" in target.dtype:
+                self.num_labels = 1
+            elif hasattr(target,'num_classes'):
+                self.num_labels=target.num_classes
+            else:
+                self.num_labels=len(set(self.dataset['train'][self.y]))

     def preprocess_function(self, examples):
         inputs = (
@@ -110,7 +124,8 @@ def __call__(self, features):
         ]
         flattened_features = sum(flattened_features, [])

-        batch = self.tokenizer.pad(flattened_features, **self.tokenizer_kwargs)
+        # Only forward the kwargs that tokenizer.pad() actually accepts.
+        pad_args=inspect.signature(self.tokenizer.pad).parameters.keys()
+        batch = self.tokenizer.pad(flattened_features, **fc.project(self.tokenizer_kwargs,pad_args))

         # Un-flatten
         batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
@@ -131,7 +146,9 @@ class MultipleChoice(Classification):
     def __post_init__(self):
         super().__post_init__()
         self.data_collator.tokenizer_kwargs = self.tokenizer_kwargs
-
+        choices = [x for x in self.dataset['train'].features if re.match(r'choice\d+',x)]
+        if choices and not self.choices:
+            self.choices=choices
     def set_tokenizer(self, tokenizer):
         self.tokenizer = self.data_collator.tokenizer= tokenizer
@@ -149,7 +166,7 @@ def preprocess_function(self, examples):
         # Tokenize
         tokenized_examples = self.tokenizer(
-            first_sentences, second_sentences, truncation=True
+            first_sentences, second_sentences, **self.tokenizer_kwargs
         )
         # Un-flatten
@@ -279,8 +296,21 @@ def preprocess_function(self, batch):
     def _explode(result,prefix=''):
         return {f'{prefix}{k}_{a}_{b}'.replace("_mid","").replace("_fmeasure",""):round(getattr(getattr(v,b),a)*100,3)\
         for (k,v) in result.items() for a in ['precision','recall','fmeasure'] for b in ['low','mid','high']}
-
+
+    @staticmethod
+    def _postprocess_text(preds, labels):
+        import nltk
+        preds = [pred.strip() for pred in preds]
+        labels = [label.strip() for label in labels]
+
+        # rougeLSum expects newline after each sentence
+        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+        return preds, labels
+
     def compute_metrics(self, eval_preds):
+        preds, labels = eval_preds
         if isinstance(preds, tuple):
             preds = preds[0]
         decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
@@ -289,7 +319,7 @@ def compute_metrics(self, eval_preds):
         decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

         # Some simple post-processing
-        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+        decoded_preds, decoded_labels = self._postprocess_text(decoded_preds, decoded_labels)
         g = decoded_preds, decoded_labels
         result = self.metric.compute(
             predictions=decoded_preds, references=decoded_labels, use_stemmer=True
diff --git a/src/tasknet/utils.py b/src/tasknet/utils.py
index 77cbee9..95a0c9f 100755
--- a/src/tasknet/utils.py
+++ b/src/tasknet/utils.py
@@ -1,4 +1,5 @@
-from datasets import DatasetDict
+from datasets import DatasetDict, Dataset, load_dataset
+from easydict import EasyDict as edict


 def train_validation_test_split(dataset, train_ratio=0.8, val_test_ratio=0.5):
@@ -10,3 +11,16 @@ def train_validation_test_split(dataset, train_ratio=0.8, val_test_ratio=0.5):
         test=test_valid["train"],
     )
     return dataset
+
+
+def load_dataset_sample(*args,n=1000):
+    ds= load_dataset(*args,streaming=True)
+    return DatasetDict({k: Dataset.from_list(list(ds[k].shuffle().take(n))) for k in ds})
+
+
+def to_dict(x):
+    if hasattr(x,'items'):
+        return edict(x)
+    else:
+        x=edict({a:getattr(x,a) for a in dir(x) if not a.startswith('__')})
+        return x
\ No newline at end of file
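Note (editorial addendum, not part of the patch): the `p` argument added to `MultitaskDataloader.__init__` works as a sampling temperature. Each task contributes `int(N * n**p)` batches, where `n` is that task's own batch count and `N = max_i(n_i**(1-p))`, so `p=1` keeps the existing size-proportional schedule, `p=0` draws as many batches from every task as from the largest one, and intermediate values interpolate. A minimal sketch of that arithmetic with made-up task sizes (the names and counts below are invented for illustration only):

    # Illustration only: task sizes are invented, not taken from the patch.
    sizes = {"mnli": 12000, "rte": 80, "copa": 13}  # batches per task

    def batches_per_task(sizes, p=1.0):
        # Same formula as f_p in MultitaskDataloader.__init__.
        N = max(n ** (1 - p) for n in sizes.values())
        return {task: int(N * n ** p) for task, n in sizes.items()}

    print(batches_per_task(sizes, p=1.0))  # size-proportional: {'mnli': 12000, 'rte': 80, 'copa': 13}
    print(batches_per_task(sizes, p=0.5))  # flattened: roughly {'mnli': 12000, 'rte': 979, 'copa': 394}
    print(batches_per_task(sizes, p=0.0))  # uniform: every task gets 12000 batches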