Skip to content

Commit 082f78f

Browse files
committed
DOC : Example for fine_tuning BERT models
1 parent 12f6cd7 commit 082f78f

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed

docs/examples/fine_tuning.md

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
---
2+
jupytext:
3+
text_representation:
4+
extension: .md
5+
format_name: myst
6+
format_version: 0.13
7+
jupytext_version: 1.16.1
8+
---
9+
10+
```{code-cell}
11+
import os
import shutil
from glob import glob  # FIX: used by the checkpoint-selection code below but was never imported

import torch

from medkit.io.medkit_json import load_text_documents
from medkit.text.metrics.ner import SeqEvalMetricsComputer
from medkit.text.ner.hf_entity_matcher import HFEntityMatcher
from medkit.training import TrainerConfig, Trainer

# Accumulators for the combined corpora.
train, val, test = [], [], []

# Merge each corpus split into one to get a large amount of data to fine-tune on.
for corpus in ["quaero", "e3c", "casm2"]:
    train += list(load_text_documents(f"/content/drive/MyDrive/datasets/{corpus}/train.jsonl"))
    val += list(load_text_documents(f"/content/drive/MyDrive/datasets/{corpus}/val.jsonl"))
    test += list(load_text_documents(f"/content/drive/MyDrive/datasets/{corpus}/test.jsonl"))
26+
```
27+
28+
```{code-cell}
29+
# FIX: `glob` was called below without ever being imported (NameError at runtime).
from glob import glob

CHECKPOINT_DIR = "checkpoints_drbert/"

# Use the first CUDA device when available, otherwise fall back to CPU (-1).
DEVICE = 0 if torch.cuda.is_available() else -1

# Wrap DrBERT as a trainable medkit NER component using IOB2 tagging.
trainable_matcher = HFEntityMatcher.make_trainable(
    model_name_or_path="Dr-BERT/DrBERT-4GB-CP-PubMedBERT",
    labels=["ANAT", "CHEM", "DEVI", "DISO", "GEOG", "LIVB", "OBJC", "PHEN", "PHYS", "PROC"],
    tagging_scheme="iob2",
    tokenizer_max_length=512,
    device=DEVICE,
    tag_subtokens=True,
)

trainer_config = TrainerConfig(
    output_dir=CHECKPOINT_DIR,
    learning_rate=5e-5,
    nb_training_epochs=10,
    batch_size=16,
)

# Entity-level metrics (seqeval), weighted-averaged across all labels.
ner_metrics_computer = SeqEvalMetricsComputer(
    id_to_label=trainable_matcher.id_to_label,
    tagging_scheme="iob2",
    return_metrics_by_label=False,
    average="weighted",
)

trainer = Trainer(
    config=trainer_config,
    component=trainable_matcher,
    train_data=train,
    eval_data=val,
    metrics_computer=ner_metrics_computer,
)

# Train the model.
history = trainer.train()

# Pick a checkpoint, rename it, and move it to the drive.
# NOTE(review): sorted()[0] selects the lexicographically FIRST checkpoint
# directory, which is not necessarily the best-scoring one — confirm the
# trainer's checkpoint naming before relying on this.
checkpoint_paths = sorted(glob(CHECKPOINT_DIR + "/checkpoint_*"))
checkpoint_path = checkpoint_paths[0]
os.rename(checkpoint_path, f'{CHECKPOINT_DIR}/DrBert-Generalized')
shutil.move(f'{CHECKPOINT_DIR}/DrBert-Generalized', '/content/drive/MyDrive/models')
72+
```
73+
74+
```{code-cell}
75+
# FIX: `glob` was called below without ever being imported (NameError at runtime).
from glob import glob

CHECKPOINT_DIR = "checkpoints_cam/"

# Use the first CUDA device when available, otherwise fall back to CPU (-1).
DEVICE = 0 if torch.cuda.is_available() else -1

# Wrap CamemBERT-bio as a trainable medkit NER component using IOB2 tagging.
trainable_matcher = HFEntityMatcher.make_trainable(
    model_name_or_path="almanach/camembert-bio-base",
    labels=["ANAT", "CHEM", "DEVI", "DISO", "GEOG", "LIVB", "OBJC", "PHEN", "PHYS", "PROC"],
    tagging_scheme="iob2",
    tokenizer_max_length=512,
    device=DEVICE,
    tag_subtokens=True,
)

trainer_config = TrainerConfig(
    output_dir=CHECKPOINT_DIR,
    learning_rate=5e-5,
    nb_training_epochs=10,
    batch_size=16,
)

# Entity-level metrics (seqeval), weighted-averaged across all labels.
ner_metrics_computer = SeqEvalMetricsComputer(
    id_to_label=trainable_matcher.id_to_label,
    tagging_scheme="iob2",
    return_metrics_by_label=False,
    average="weighted",
)

trainer = Trainer(
    config=trainer_config,
    component=trainable_matcher,
    train_data=train,
    eval_data=val,
    metrics_computer=ner_metrics_computer,
)

# Train the model.
history = trainer.train()

# Pick a checkpoint, rename it, and move it to the drive.
# NOTE(review): sorted()[0] selects the lexicographically FIRST checkpoint
# directory, which is not necessarily the best-scoring one — confirm the
# trainer's checkpoint naming before relying on this.
checkpoint_paths = sorted(glob(CHECKPOINT_DIR + "/checkpoint_*"))
checkpoint_path = checkpoint_paths[0]
os.rename(checkpoint_path, f'{CHECKPOINT_DIR}/CamemBert-Bio-Generalized')
shutil.move(f'{CHECKPOINT_DIR}/CamemBert-Bio-Generalized', '/content/drive/MyDrive/models')
118+
```

0 commit comments

Comments
 (0)