|
| 1 | +--- |
| 2 | +jupytext: |
| 3 | + text_representation: |
| 4 | + extension: .md |
| 5 | + format_name: myst |
| 6 | + format_version: 0.13 |
| 7 | + jupytext_version: 1.16.1 |
| 8 | +--- |
| 9 | + |
| 10 | +```{code-cell} |
| 11 | +import torch |
| 12 | +from medkit.training import TrainerConfig, Trainer |
| 13 | +from medkit.text.metrics.ner import SeqEvalMetricsComputer |
| 14 | +from medkit.text.ner.hf_entity_matcher import HFEntityMatcher |
| 15 | +from medkit.io.medkit_json import load_text_documents |
| 16 | +import os |
| 17 | +import shutil |
| 18 | +
|
| 19 | +train, val, test = [], [], [] |
| 20 | +
|
| 21 | +#Merge each corpus split into one to get a massive amount of data to fine-tune on |
| 22 | +for c in ['quaero','e3c', 'casm2']: |
| 23 | + train += list(load_text_documents(f"/content/drive/MyDrive/datasets/{c}/train.jsonl")) |
| 24 | + val += list(load_text_documents(f"/content/drive/MyDrive/datasets/{c}/val.jsonl")) |
| 25 | + test += list(load_text_documents(f"/content/drive/MyDrive/datasets/{c}/test.jsonl")) |
| 26 | +``` |
| 27 | + |
| 28 | +```{code-cell} |
| 29 | +CHECKPOINT_DIR = "checkpoints_drbert/" |
| 30 | +
|
| 31 | +DEVICE = 0 if torch.cuda.is_available() else -1 |
| 32 | +
|
| 33 | +trainable_matcher = HFEntityMatcher.make_trainable( |
| 34 | + model_name_or_path="Dr-BERT/DrBERT-4GB-CP-PubMedBERT", |
| 35 | + labels=["ANAT","CHEM","DEVI","DISO","GEOG","LIVB","OBJC","PHEN","PHYS","PROC"], |
| 36 | + tagging_scheme="iob2", |
| 37 | + tokenizer_max_length=512, |
| 38 | + device=DEVICE, |
| 39 | + tag_subtokens=True |
| 40 | +) |
| 41 | +
|
| 42 | +trainer_config = TrainerConfig( |
| 43 | + output_dir=CHECKPOINT_DIR, |
| 44 | + learning_rate=5e-5, |
| 45 | + nb_training_epochs=10, |
| 46 | + batch_size=16, |
| 47 | +) |
| 48 | +
|
| 49 | +ner_metrics_computer = SeqEvalMetricsComputer( |
| 50 | + id_to_label=trainable_matcher.id_to_label, |
| 51 | + tagging_scheme='iob2', |
| 52 | + return_metrics_by_label=False, |
| 53 | + average='weighted' |
| 54 | +) |
| 55 | +
|
| 56 | +trainer = Trainer( |
| 57 | + config=trainer_config, |
| 58 | + component=trainable_matcher, |
| 59 | + train_data=train, |
| 60 | + eval_data=val, |
| 61 | + metrics_computer=ner_metrics_computer, |
| 62 | +) |
| 63 | +
|
| 64 | +#Train model |
| 65 | +history = trainer.train() |
| 66 | +
|
| 67 | +#Get best checkpoint, rename it and save it on my local drive |
| 68 | +checkpoint_paths = sorted(glob(CHECKPOINT_DIR + "/checkpoint_*")) |
| 69 | +checkpoint_path = checkpoint_paths[0] |
| 70 | +os.rename(checkpoint_path, f'{CHECKPOINT_DIR}/DrBert-Generalized') |
| 71 | +shutil.move(f'{CHECKPOINT_DIR}/DrBert-Generalized','/content/drive/MyDrive/models') |
| 72 | +``` |
| 73 | + |
| 74 | +```{code-cell} |
| 75 | +CHECKPOINT_DIR = "checkpoints_cam/" |
| 76 | +
|
| 77 | +DEVICE = 0 if torch.cuda.is_available() else -1 |
| 78 | +
|
| 79 | +trainable_matcher = HFEntityMatcher.make_trainable( |
| 80 | + model_name_or_path="almanach/camembert-bio-base", |
| 81 | + labels=["ANAT","CHEM","DEVI","DISO","GEOG","LIVB","OBJC","PHEN","PHYS","PROC"], |
| 82 | + tagging_scheme="iob2", |
| 83 | + tokenizer_max_length=512, |
| 84 | + device=DEVICE, |
| 85 | + tag_subtokens=True |
| 86 | +) |
| 87 | +
|
| 88 | +trainer_config = TrainerConfig( |
| 89 | + output_dir=CHECKPOINT_DIR, |
| 90 | + learning_rate=5e-5, |
| 91 | + nb_training_epochs=10, |
| 92 | + batch_size=16 |
| 93 | +) |
| 94 | +
|
| 95 | +ner_metrics_computer = SeqEvalMetricsComputer( |
| 96 | + id_to_label=trainable_matcher.id_to_label, |
| 97 | + tagging_scheme='iob2', |
| 98 | + return_metrics_by_label=False, |
| 99 | + average='weighted' |
| 100 | +) |
| 101 | +
|
| 102 | +trainer = Trainer( |
| 103 | + config=trainer_config, |
| 104 | + component=trainable_matcher, |
| 105 | + train_data=train, |
| 106 | + eval_data=val, |
| 107 | + metrics_computer=ner_metrics_computer, |
| 108 | +) |
| 109 | +
|
| 110 | +#Train model |
| 111 | +history = trainer.train() |
| 112 | +
|
| 113 | +#Get best checkpoint, rename it and save it on my local drive |
| 114 | +checkpoint_paths = sorted(glob(CHECKPOINT_DIR + "/checkpoint_*")) |
| 115 | +checkpoint_path = checkpoint_paths[0] |
| 116 | +os.rename(checkpoint_path, f'{CHECKPOINT_DIR}/CamemBert-Bio-Generalized') |
| 117 | +shutil.move(f'{CHECKPOINT_DIR}/CamemBert-Bio-Generalized','/content/drive/MyDrive/models') |
| 118 | +``` |
0 commit comments