Commit a244fc1
sileod committed Jan 6, 2023
1 parent d57ebe1 commit a244fc1
Showing 4 changed files with 543 additions and 50 deletions.
96 changes: 58 additions & 38 deletions src/tasknet/models.py
@@ -8,65 +8,88 @@
from torch.utils.data.sampler import RandomSampler, WeightedRandomSampler
from typing import List, Union, Dict
from transformers import (
-    AutoModelForSeq2SeqLM,
+    EncoderDecoderModel,
+    DataCollatorForSeq2Seq,
+    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForMultipleChoice,
    AutoModelForTokenClassification,
)
-from transformers import EncoderDecoderModel
from easydict import EasyDict as edict
import funcy as fc
import copy
import logging
import functools
from types import MappingProxyType
from .tasks import Classification
-from .utils import to_dict
+from .utils import to_dict, shallow_copy_A_to_B, deep_copy_cache, normalize_label
from transformers import AutoTokenizer
import magicattr
import gc
import random

def progress(l):
    try:
        from tqdm.auto import tqdm
        assert len(l) > 8
        return tqdm(l)
    except:
        return l


class CLSEmbedding(nn.Module):
-    def __init__(self, Zi):
+    def __init__(self, Zi, drop_probability=0.0):
        super().__init__()
        self.cls = Zi
+        self.drop_probability = drop_probability

    def forward(self, x):
-        x[:, 0, :] = x[:, 0, :] + self.cls
+        if random.random() > self.drop_probability:
+            x[:, 0, :] = x[:, 0, :] + self.cls
        return x
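For orientation, a hedged standalone sketch of what this layer does when wrapped around an embedding module (it assumes the usual (batch, sequence, hidden) layout of transformers embeddings; the import path is an assumption):

    import torch
    from torch import nn
    from tasknet.models import CLSEmbedding  # assumed import path

    # one learned task embedding, added to the CLS position of every sequence,
    # and randomly skipped with probability drop_probability
    hidden = 8
    Z_i = nn.Parameter(torch.zeros(hidden))
    emb = nn.Sequential(nn.Embedding(100, hidden),
                        CLSEmbedding(Z_i, drop_probability=0.1))
    tokens = torch.randint(0, 100, (2, 5))   # (batch, sequence)
    out = emb(tokens)                        # Z_i added to out[:, 0, :] ~90% of the time
    assert out.shape == (2, 5, hidden)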

class WandbTaskCallback(transformers.integrations.WandbCallback):

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        import wandb
        if not self._initialized:
            self.setup(args, state, model, reinit=False)
        if state.is_world_process_zero:
            if 'eval_name' in logs:
                logs = {f"{logs['eval_name']}/{k}": v for (k, v) in logs.items() if k != "eval_name"}
            wandb.log(logs, step=state.global_step)
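The renaming in on_log namespaces each task's metrics under its eval_name. Applying the comprehension above to a sample log dict:

    logs = {"eval_name": "rte", "eval_accuracy": 0.71}
    logs = {f"{logs['eval_name']}/{k}": v for (k, v) in logs.items() if k != "eval_name"}
    print(logs)  # {'rte/eval_accuracy': 0.71}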

class Model(transformers.PreTrainedModel):
    def __init__(self, tasks, args, warm_start=None):
        super().__init__(transformers.PretrainedConfig())
        args = to_dict(args)
        self.shared_encoder = warm_start
        mc_model = None
        self.models = {}
        task_models_list = []
-        for i, task in enumerate(tasks):
+        for i, task in progress(enumerate(tasks)):
            model_type = eval(f"AutoModelFor{task.task_type}")
            nl = {a: getattr(task, a) for a in ('num_labels', 'problem_type')
                  if hasattr(task, a)}

-            model = model_type.from_pretrained(args.model_name, **nl)
+            model = deep_copy_cache(model_type.from_pretrained)(args.model_name, **nl)

            if task.task_type == 'MultipleChoice':
-                key = "mc"
+                key = task.task_type
            else:
                labels = getattr(task.dataset['train'].features[task.y], "names", None)
-                key = (tuple(labels) if labels else None)
+                key = tuple([normalize_label(x) for x in labels]) if labels else None
+                key = key if task.num_labels != 2 or key else "binary"

            if key and key not in self.models:
                self.models[key] = model
            if key and key in self.models:
-                self.shallow_copy(self.models[key].classifier, model.classifier)
+                model.classifier.weight = self.models[key].classifier.weight

            model.auto = getattr(model, self.get_encoder_attr_name(model))

            if self.shared_encoder is None:
                self.shared_encoder = model.auto
            else:
-                self.shallow_copy(self.shared_encoder, model.auto)
+                shallow_copy_A_to_B(self.shared_encoder, model.auto)

            task_models_list += [model]
            model.i = i
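For orientation, a hedged usage sketch of this constructor: each task yields its own AutoModelFor<task_type> head, every head's encoder is tied to one shared encoder, and classifier heads are shared between tasks whose normalized label names coincide. The package-level names and the Classification column arguments (s1, s2, y) below are assumptions, and the GLUE dataset is purely illustrative. Note also that progress() receives enumerate(tasks) here, which has no len(), so the tqdm branch is skipped and a plain iterator comes back.

    from datasets import load_dataset
    import tasknet as tn  # assumed package-level re-exports of Model and Classification

    rte = tn.Classification(dataset=load_dataset("glue", "rte"),
                            s1="sentence1", s2="sentence2", y="label")
    args = dict(model_name="bert-base-uncased",   # encoder checkpoint, read as args.model_name
                cls_emb_drop_probability=0.1)     # read via args.get(...) further down
    model = tn.Model([rte], args)                 # one head; encoder shared with any later tasks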
@@ -85,32 +108,18 @@ def __init__(self, tasks, args, warm_start=None):
        emb_name, emb_module = [(name, module) for name, module in m_i.named_modules()
                                if isinstance(module, torch.nn.Embedding)][0]

        magicattr.set(m_i, emb_name,
-            nn.Sequential(emb_module, CLSEmbedding(self.Z[i]))
+            nn.Sequential(emb_module,
+                CLSEmbedding(
+                    self.Z[i],
+                    drop_probability=args.get('cls_emb_drop_probability', 0.0))
+            )
        )
        torch.cuda.empty_cache()
        gc.collect()

    def set_encoder(self, encoder):
        for model in self.task_models_list:
-            self.shallow_copy(encoder, getattr(model, self.get_encoder_attr_name(model)))
-
-    @staticmethod
-    def shallow_copy(A, B):
-        """Shallow copy (=parameter sharing) A into B
-        https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427"""
-
-        def rsetattr(obj, attr, val):
-            pre, _, post = attr.rpartition(".")
-            return setattr(rgetattr(obj, pre) if pre else obj, post, val)
-
-        def rgetattr(obj, attr, *args):
-            def _getattr(obj, attr):
-                return getattr(obj, attr, *args)
-            return functools.reduce(_getattr, [obj] + attr.split("."))
-
-        for (na, _), (nb, _) in zip(A.named_parameters(), B.named_parameters()):
-            rsetattr(B, nb, rgetattr(A, na))
-        return A, B
+            shallow_copy_A_to_B(encoder, getattr(model, self.get_encoder_attr_name(model)))

    @classmethod
    def get_encoder_attr_name(cls, model):
@@ -249,8 +258,18 @@ class default:
        # transformers.Trainer recognizes eval_dataset instances of "dict",
        # but we use a custom "evaluate" function so that we can use different metrics for each task
        self.eval_dataset = MappingProxyType(self.eval_dataset)
        self.fix_callback()
        self.cleanup_outputs()

    def fix_callback(self):
        try:
            import wandb
        except:
            return
        i = [i for (i, c) in enumerate(self.callback_handler.callbacks) if 'Wandb' in str(c)]
        if i:
            self.callback_handler.callbacks[i[0]] = WandbTaskCallback()

    @staticmethod
    def cleanup_outputs():
        try:
@@ -274,8 +293,9 @@ def write_line(other, values):

    def evaluate(self, **kwargs):
        try:
-            self.callback_handler.callbacks[-1].training_tracker.write_line = fc.partial(
-                self.write_line, self.callback_handler.callbacks[-1].training_tracker
+            i = [i for (i, c) in enumerate(self.callback_handler.callbacks) if 'NotebookProgress' in str(c)][0]
+            self.callback_handler.callbacks[i].training_tracker.write_line = fc.partial(
+                self.write_line, self.callback_handler.callbacks[i].training_tracker
            )
        except:
            logging.info('No training_tracker')
31 changes: 20 additions & 11 deletions src/tasknet/tasks.py
@@ -27,13 +27,23 @@ def get_dataset_name(dataset):
    except:
        return ""

-def sample_dataset(dataset, n=10000, n_eval=1000):
+def oversample(dataset, n=2):
+    dataset['train'] = datasets.concatenate_datasets(
+        [dataset['train'].shuffle(_) for _ in range(n)]
+    )
+    return dataset
+
+def sample_dataset(dataset, n=10000, n_eval=1000, oversampling=None):
+    if oversampling and len(dataset['train']) < n:
+        dataset = oversample(dataset, oversampling)
+
    for k in dataset:
        n_k = (n if k == 'train' else n_eval)
        if n_k and len(dataset[k]) > n_k:
            dataset[k] = dataset[k].train_test_split(train_size=n_k)['train']
    return dataset
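A hedged sketch of the two helpers above on toy data (row counts are illustrative only; the import path is an assumption):

    from datasets import Dataset, DatasetDict
    from tasknet.tasks import sample_dataset  # assumed import path

    dd = DatasetDict(train=Dataset.from_dict({"x": list(range(100))}),
                     validation=Dataset.from_dict({"x": list(range(50))}))
    # train has 100 < 10000 rows, so it is oversampled into 3 shuffled copies;
    # neither split exceeds its cap, so nothing is truncated afterwards
    dd = sample_dataset(dd, n=10000, n_eval=1000, oversampling=3)
    print(len(dd["train"]))  # 300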


@dataclass
class Task:
    dataset: Dataset = None
@@ -42,6 +52,7 @@ class Task:
    tokenizer_kwargs: ... = fdict(padding="max_length", max_length=256, truncation=True)
    max_rows: int = None
    max_rows_eval: int = None
+    oversampling: int = None

    def __hash__(self):
        return hash(str(self.dataset.__dict__))
@@ -60,7 +71,8 @@ def __post_init__(self):

        if not self.name:
            self.name = name
-        self.dataset = sample_dataset(self.dataset, self.max_rows, self.max_rows_eval)
+        self.results = []
+        self.dataset = sample_dataset(self.dataset, self.max_rows, self.max_rows_eval, self.oversampling)

    def check():
        return True
@@ -130,7 +142,9 @@ def compute_metrics(self, eval_pred):
        else:
            metric = load_metric("glue", "stsb")
        meta = {"name": self.name, "size": len(predictions), "index": self.index}
-        return {**metric.compute(predictions=predictions, references=labels, **avg), **meta}
+        metrics = metric.compute(predictions=predictions, references=labels, **avg)
+        self.results += [metrics]
+        return {**metrics, **meta}


@dataclass
@@ -283,14 +297,9 @@ def compute_metrics(self, eval_pred):
            predictions=true_predictions, references=true_labels
        )
        meta = {"name": self.name, "size": len(predictions), "index": self.index}
-
-        return {
-            "precision": all_metrics["overall_precision"],
-            "recall": all_metrics["overall_recall"],
-            "f1": all_metrics["overall_f1"],
-            "accuracy": all_metrics["overall_accuracy"],
-            **meta,
-        }
+        metrics = {k.replace("overall_", ""): v for k, v in all_metrics.items() if "overall" in k}
+        self.results += [metrics]
+        return {**metrics, **meta}

    def check(self):
        features = self.dataset['train'].features
44 changes: 43 additions & 1 deletion src/tasknet/utils.py
@@ -1,5 +1,7 @@
from datasets import DatasetDict, Dataset, load_dataset
from easydict import EasyDict as edict
+import copy
+import functools


def train_validation_test_split(dataset, train_ratio=0.8, val_test_ratio=0.5, seed=0):
@@ -23,4 +25,44 @@ def to_dict(x):
        return edict(x)
    else:
        x = edict({a: getattr(x, a) for a in dir(x) if not a.startswith('__')})
    return x

def deep_copy_cache(function):
    memo = {}
    def wrapper(*args, **kwargs):
        if args in memo:
            return copy.deepcopy(memo[args])
        else:
            rv = function(*args, **kwargs)
            memo[args] = rv
            return rv
    return wrapper
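A hedged usage sketch of deep_copy_cache: the first call loads and caches; later calls with the same positional arguments return a deep copy, so repeated heads start from identical weights without reloading. Note the memo key is args only, so calls differing solely in keyword arguments share a cache entry.

    from transformers import AutoModelForSequenceClassification

    cached = deep_copy_cache(AutoModelForSequenceClassification.from_pretrained)
    m1 = cached("bert-base-uncased")  # loads once (and is the cached object itself)
    m2 = cached("bert-base-uncased")  # deep copy of the cached model, no reload
    assert m1 is not m2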

def shallow_copy_A_to_B(A, B):
    """Shallow copy (=parameter sharing) A into B
    https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427"""

    def rsetattr(obj, attr, val):
        pre, _, post = attr.rpartition(".")
        return setattr(rgetattr(obj, pre) if pre else obj, post, val)

    def rgetattr(obj, attr, *args):
        def _getattr(obj, attr):
            return getattr(obj, attr, *args)
        return functools.reduce(_getattr, [obj] + attr.split("."))

    for (na, _), (nb, _) in zip(A.named_parameters(), B.named_parameters()):
        rsetattr(B, nb, rgetattr(A, na))
    return A, B
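A minimal standalone check of the parameter sharing this performs, using two small Linear modules:

    import torch.nn as nn

    A, B = nn.Linear(4, 4), nn.Linear(4, 4)
    shallow_copy_A_to_B(A, B)
    # B's parameters are now the very same objects as A's, i.e. tied weights
    assert B.weight is A.weight and B.bias is A.bias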

def normalize_label(label):
    label = str(label).lower()
    label = label.replace('-', '_')
    label = label.replace(' ', '_')
    label = label.replace('entailed', 'entailment')
    label = label.replace('non_', 'not_')
    label = label.replace('duplicate', 'equivalent')
    label = label.replace('neg', 'negative')
    label = label.replace('pos', 'positive')
    return label
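A few hand-checked examples of the substring rewriting above:

    assert normalize_label("Not-Entailed") == "not_entailment"
    assert normalize_label("non-entailment") == "not_entailment"
    assert normalize_label("DUPLICATE") == "equivalent"
    # the replacements are plain ordered substring edits, so labels already in
    # canonical form can be rewritten further, e.g. "negative" -> "negativeative"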