This package enables researchers to easily use state-of-the-art transformer models and large language models for disease onset prediction from electronic health records (EHRs).
It also implements a visualization technique that provides insight into the important narrative features driving each prediction, and it covers the end-to-end process from training to prediction.
The package is built on top of the Transformers library developed by Hugging Face. The packages required to run the project are listed in requirement.txt. The following models are currently supported:
- GatorTron-base (345m parameters) https://huggingface.co/UFNLP/gatortron-base
- GatorTron-medium (3.9b parameters) https://huggingface.co/UFNLP/gatortron-medium
- GatorTron-large (8.9b parameters) https://huggingface.co/UFNLP/gatortron-large
- BERT
- XLNet
- RoBERTa
- ALBERT
- DeBERTa
- Longformer
We will keep adding new models.
- Prerequisite
To use the package for text classification, you need to provide unstructured text data with labels.
Preprocessing is required to truncate the text to each model's maximum token limit and to ensure it is in the correct format. Once preprocessing is complete, you can run the package to obtain end-to-end classification prediction results.
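As a rough illustration of this truncation step, the sketch below trims a note to a fixed token budget while keeping the tagged entity span. It is a simplified sketch: whitespace tokens only approximate a model's subword tokens, and the helper name and centering strategy are our own assumptions, not part of the package.

```python
# Illustrative preprocessing sketch (assumption: whitespace tokens approximate
# the model's subword tokens; the real limit depends on the actual tokenizer).
def truncate_around_entity(text, max_tokens=512, start_tag="[s]", end_tag="[e]"):
    """Truncate `text` to at most `max_tokens` whitespace tokens,
    keeping the [s] ... [e] entity span when possible."""
    tokens = text.split()
    if len(tokens) <= max_tokens:
        return text
    try:
        s_idx = tokens.index(start_tag)
        e_idx = tokens.index(end_tag)
    except ValueError:
        # No entity tags: keep the first max_tokens tokens.
        return " ".join(tokens[:max_tokens])
    # Budget the remaining context around the entity span.
    span_len = e_idx - s_idx + 1
    context = max(0, max_tokens - span_len)
    left = max(0, s_idx - context // 2)
    right = min(len(tokens), e_idx + 1 + (context - (s_idx - left)))
    return " ".join(tokens[left:right])

note = " ".join(["word"] * 600) + " [s] atrial fibrillation [e] " + " ".join(["word"] * 100)
short = truncate_around_entity(note, max_tokens=128)
print(len(short.split()))  # → 128, with the [s] ... [e] span retained
```

For production use, count tokens with the tokenizer of the model you intend to fine-tune rather than with whitespace splitting.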
- Data Format
See the sample_data directory for the train, dev, and test data formats.
We do not provide a data preprocessing script; you can follow our example data format to generate your own dataset.
# Data Format: tsv file with 4 columns:
1. target: 1
2. text: [s] cerebral arteries, atelectasis, localized swelling [e] .
3. entity_type: Drug
4. pat_id: id_1
Note:
1) The text between [s] and [e] is the paragraph we use for text classification.
2) In test.tsv, you can set all labels to neg, False, or any placeholder, because the labels are not used during prediction.
3) We recommend evaluating the test performance in a separate process based on the predictions (see **post-processing**).
4) We recommend using official evaluation scripts for evaluation to make sure the reported results are reliable.
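For reference, here is a minimal sketch that writes rows in the 4-column TSV layout shown above, using Python's standard-library csv module. Whether sample_data expects a header row should be checked against the provided files; the header here is an assumption for illustration.

```python
import csv

# Minimal sketch: write rows in the 4-column TSV layout described above
# (target, text, entity_type, pat_id). The header row is an assumption;
# compare against the files in sample_data before training.
rows = [
    {"target": "1",
     "text": "[s] cerebral arteries, atelectasis, localized swelling [e] .",
     "entity_type": "Drug",
     "pat_id": "id_1"},
]

with open("train.tsv", "w", newline="") as f:
    writer = csv.DictWriter(
        f,
        fieldnames=["target", "text", "entity_type", "pat_id"],
        delimiter="\t",
    )
    writer.writeheader()
    writer.writerows(rows)
```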
- Special Tags
We use two special tags to identify entities.
# The default tags defined in the repo are
EN_START = "[s]"
EN_END = "[e]"
If you need to customize these tags, you can change them in config.py.
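As a small illustration of how these tags mark an entity span in a note, consider the sketch below. The helper function is hypothetical (not part of the package); only the tag values match the defaults above.

```python
# Hypothetical helper: wrap an entity mention with the default tags above.
EN_START = "[s]"
EN_END = "[e]"

def tag_entity(text, start, end):
    """Insert [s]/[e] around the character span [start, end) in `text`."""
    return f"{text[:start]}{EN_START} {text[start:end]} {EN_END}{text[end:]}"

sentence = "Patient shows localized swelling after treatment."
tagged = tag_entity(sentence, 14, 32)
print(tagged)  # → Patient shows [s] localized swelling [e] after treatment.
```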
- Visualization
The LIME (Local Interpretable Model-agnostic Explanations) package is used to visualize the important narrative features. For more details, please refer to: https://github.com/marcotcr/lime
Note: Because of the special tags, two lines of code in the following file need to be modified before running the visualization.
File path: envs/[envs_name]/lib/python3.9/site-packages/lime/lime_text.py
Under the function 'def distance_fn(x):':
1) change 'sample = self.random_state.randint(1, doc_size + 1, num_samples - 1)' to 'sample = self.random_state.randint(1, doc_size + 1 - 6, num_samples - 1)'
2) change 'features_range = range(doc_size)' to 'features_range = range(3, doc_size - 3)'
- Training/ Evaluation/ Prediction
Please refer to the wiki page for details of all parameter flags.
highlight_index: row index in the test dataset to visualize; highlight_output_file: visualization output file path.
export CUDA_VISIBLE_DEVICES=1
data_dir=./sample_data
nmd=./new_modelzw
pof=./predictions.txt
log=./log.txt
log_highlight=./log_highlight.html
# NOTE: we have more options available, you can check our wiki for more information
python ./src/relation_extraction.py \
--model_type bert \
--data_format_mode 0 \
--classification_scheme 1 \
--pretrained_model UFNLP/gatortron-medium \
--data_dir $data_dir \
--new_model_dir $nmd \
--predict_output_file $pof \
--overwrite_model_dir \
--seed 13 \
--max_seq_length 512 \
--cache_data \
--do_train \
--do_eval \
--do_predict \
--do_predict_highlight \
--highlight_index 8 \
--highlight_output_file $log_highlight \
--do_lower_case \
--train_batch_size 4 \
--eval_batch_size 4 \
--learning_rate 1e-5 \
--num_train_epochs 3 \
--gradient_accumulation_steps 1 \
--do_warmup \
--warmup_ratio 0.1 \
--weight_decay 0 \
--max_num_checkpoints 1 \
--log_file $log
Here are some important output files that might be helpful for the analysis of the model's performance and learning progression:
_prob.tsv: the prediction results in two columns: the predicted target label and its corresponding probability
_.html: the visualization of keywords from the selected text sample
epoch_loss.png: the average training and validation loss for each epoch
epoch_loss_metrics.png: the validation accuracy, F1, precision, recall, and training loss for each epoch
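Since we recommend evaluating test performance in a separate post-processing step, the sketch below shows one way to score predictions from _prob.tsv against your own gold labels. It assumes the two-column layout described above (predicted label, probability), aligned row by row with your labeled test set; the helper names are our own, and for reported results you should prefer an official evaluation script.

```python
# Post-processing sketch (assumption: _prob.tsv has two tab-separated columns,
# the predicted label and its probability, aligned with your gold test labels).
import csv

def load_predictions(path):
    """Read the predicted labels (first column) from a _prob.tsv file."""
    with open(path) as f:
        return [row[0] for row in csv.reader(f, delimiter="\t")]

def binary_metrics(gold, pred, positive="1"):
    """Compute precision, recall, and F1 for the positive class."""
    tp = sum(g == positive and p == positive for g, p in zip(gold, pred))
    fp = sum(g != positive and p == positive for g, p in zip(gold, pred))
    fn = sum(g == positive and p != positive for g, p in zip(gold, pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

gold = ["1", "0", "1", "1"]        # from your own labeled test set
pred = ["1", "1", "1", "0"]        # e.g. load_predictions("predictions_prob.tsv")
print(binary_metrics(gold, pred))  # (precision, recall, f1) = (2/3, 2/3, 2/3)
```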
Please cite our paper: https://arxiv.org/abs/2403.11425
@article{chen2024narrative,
title={Narrative Feature or Structured Feature? A Study of Large Language Models to Identify Cancer Patients at Risk of Heart Failure},
author={Chen, Ziyi and Zhang, Mengyuan and Ahmed, Mustafa Mohammed and Guo, Yi and George, Thomas J and Bian, Jiang and Wu, Yonghui},
journal={arXiv preprint arXiv:2403.11425},
year={2024}
}
Please contact us or post an issue if you have any questions.
- Ziyi Chen ([email protected])
- Yonghui Wu ([email protected])
We have a series of transformer models pre-trained on MIMIC-III. You can find them here:
- https://transformer-models.s3.amazonaws.com/mimiciii_albert_10e_128b.zip
- https://transformer-models.s3.amazonaws.com/mimiciii_bert_10e_128b.zip
- https://transformer-models.s3.amazonaws.com/mimiciii_electra_5e_128b.zip
- https://transformer-models.s3.amazonaws.com/mimiciii_roberta_10e_128b.zip
- https://transformer-models.s3.amazonaws.com/mimiciii_xlnet_5e_128b.zip
- https://transformer-models.s3.amazonaws.com/mimiciii_deberta_10e_128b.tar.gz
- https://transformer-models.s3.amazonaws.com/mimiciii_longformer_5e_128b.zip