The torchkeras library is a simple tool for training neural networks in PyTorch in a Keras style. 😋😋
With torchkeras, you do not need to write a long training loop of your own. All you need are the
two steps below:
(i) Create your network and wrap it together with the loss function in torchkeras.KerasModel, like this:

    model = torchkeras.KerasModel(net, loss_fn=nn.BCEWithLogitsLoss())

(ii) Fit your model with the training data and the validation data.
The core code for using torchkeras looks like this:

    import torch
    import torchmetrics
    import torchkeras
    from torch import nn

    # net, dl_train and dl_val are your own network and DataLoaders, defined elsewhere.
    model = torchkeras.KerasModel(net,
                                  loss_fn=nn.BCEWithLogitsLoss(),
                                  optimizer=torch.optim.Adam(net.parameters(), lr=0.001),
                                  metrics_dict={"acc": torchmetrics.Accuracy(task='binary')}
                                  )
    dfhistory = model.fit(train_data=dl_train,
                          val_data=dl_val,
                          epochs=20,
                          patience=3,
                          ckpt_path='checkpoint',
                          monitor="val_acc",
                          mode="max",
                          plot=True
                          )

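
The `patience`, `monitor` and `mode` arguments above control early stopping. As a rough illustration of the idea only, here is a framework-free sketch. The function name `best_epoch_with_patience` and its exact counting rule are assumptions made for this example, not torchkeras code, and the library's own stopping logic may differ in detail.

```python
def best_epoch_with_patience(history, mode="max", patience=3):
    """Return (best_epoch, stop_epoch) for a list of per-epoch monitored values.

    Epochs are numbered from 1. The loop stops once `patience` consecutive
    epochs pass without improving on the best value seen so far.
    This is a hypothetical helper, not part of torchkeras.
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")

    def is_better(new, best):
        return new > best if mode == "max" else new < best

    best_value, best_epoch, stale = None, 0, 0
    for epoch, value in enumerate(history, start=1):
        if best_value is None or is_better(value, best_value):
            # New best: remember it and reset the counter of stale epochs.
            best_value, best_epoch, stale = value, epoch, 0
        else:
            stale += 1
            if stale >= patience:
                return best_epoch, epoch
    # Ran out of epochs before patience was exhausted.
    return best_epoch, len(history)


# Made-up validation accuracies, monitored with mode="max" and patience=3.
val_acc = [0.70, 0.74, 0.78, 0.77, 0.78, 0.76, 0.80]
print(best_epoch_with_patience(val_acc, mode="max", patience=3))  # (3, 6)
```

In this toy run the best value is reached at epoch 3, and the loop stops at epoch 6 after three epochs without a strict improvement, so the later 0.80 is never seen. That trade-off is what a larger `patience` is meant to soften.
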
Besides, you can use torchkeras.VLog to get dynamic training visualization anywhere you like ~

    import time
    import math, random
    from torchkeras import VLog

    epochs = 10
    batchs = 30

    # 0, init vlog
    vlog = VLog(epochs, monitor_metric='val_loss', monitor_mode='min')

    # 1, log_start
    vlog.log_start()

    for epoch in range(epochs):

        # train
        for step in range(batchs):

            # 2, log_step (for training step)
            vlog.log_step({'train_loss': 100 - 2.5*epoch + math.sin(2*step/batchs)})
            time.sleep(0.05)

        # eval
        for step in range(20):

            # 3, log_step (for eval step)
            vlog.log_step({'val_loss': 100 - 2*epoch + math.sin(2*step/batchs)}, training=False)
            time.sleep(0.05)

        # 4, log_epoch
        vlog.log_epoch({'val_loss': 100 - 2*epoch + 2*random.random() - 1,
                        'train_loss': 100 - 2.5*epoch + 2*random.random() - 1})

    # 5, log_end
    vlog.log_end()

This project may look powerful, but its source code is very simple: only about 200 lines of Python.
If you want to understand or modify any details of this project, feel free to read and change the source code!
The main features supported by torchkeras are listed below, together with the version in which each
feature was introduced and the libraries it uses or was inspired by.

feature | supported from version | uses or inspired by |
---|---|---|
✅ training progress bar | 3.0.0 | uses tqdm, inspired by keras |
✅ training metrics | 3.0.0 | inspired by pytorch_lightning |
✅ notebook visualization during training | 3.8.0 | inspired by fastai |
✅ early stopping | 3.0.0 | inspired by keras |
✅ GPU training | 3.0.0 | uses accelerate |
✅ multi-GPU training (DDP) | 3.6.0 | uses accelerate |
✅ fp16/bf16 training | 3.6.0 | uses accelerate |
✅ tensorboard callback | 3.7.0 | uses tensorboard |
✅ wandb callback | 3.7.0 | uses wandb |
✅ VLog | 3.9.5 | uses matplotlib |

You can follow these full examples to get started with torchkeras.

example | read notebook code | run example in kaggle |
---|---|---|
①kerasmodel basic 🔥🔥 | torchkeras.KerasModel example | |
②kerasmodel wandb 🔥🔥🔥 | torchkeras.KerasModel with wandb demo | |
③kerasmodel tuning 🔥🔥🔥 | torchkeras.KerasModel with wandb sweep demo | |
④kerasmodel tensorboard | torchkeras.KerasModel with tensorboard example | |
⑤kerasmodel ddp/tpu | torchkeras.KerasModel ddp tpu examples | |
⑥VLog for lightgbm/ultralytics/transformers 🔥🔥🔥 | VLog example | |

In some use cases, because the model's input types differ, you need to rewrite the StepRunner of KerasModel. Here are some examples.
example | model library | notebook |
---|---|---|
RL | ||
ReinforcementLearning——Q-Learning🔥🔥 | - | Q-learning |
ReinforcementLearning——DQN | - | DQN |
Tabular | ||
BinaryClassification——LightGBM | - | LightGBM |
MultiClassification——FTTransformer🔥🔥🔥🔥🔥 | - | FTTransformer |
BinaryClassification——FM | - | FM |
BinaryClassification——DeepFM | - | DeepFM |
BinaryClassification——DeepCross | - | DeepCross |
CV | ||
ImageClassification——Resnet | - | Resnet |
ImageSegmentation——UNet | - | UNet |
ObjectDetection——SSD | - | SSD |
OCR——CRNN 🔥🔥 | - | CRNN-CTC |
ImageClassification——SwinTransformer | timm | Swin |
ObjectDetection——FasterRCNN | torchvision | FasterRCNN |
ImageSegmentation——DeepLabV3++ | segmentation_models_pytorch | Deeplabv3++ |
InstanceSegmentation——MaskRCNN | detectron2 | MaskRCNN |
ObjectDetection——YOLOv8 🔥🔥🔥 | ultralytics | YOLOv8 |
InstanceSegmentation——YOLOv8 🔥🔥🔥 | ultralytics | YOLOv8 |
NLP | ||
Seq2Seq——Transformer🔥🔥 | - | Transformer |
TextGeneration——Llama🔥 | - | Llama |
TextClassification——BERT | transformers | BERT |
TokenClassification——BERT | transformers | BERT_NER |
FinetuneLLM——ChatGLM2_LoRA 🔥🔥🔥 | transformers,peft | ChatGLM2_LoRA |
FinetuneLLM——ChatGLM2_AdaLoRA 🔥 | transformers,peft | ChatGLM2_AdaLoRA |
FinetuneLLM——ChatGLM2_QLoRA🔥 | transformers | ChatGLM2_QLoRA_Kaggle |
FinetuneLLM——BaiChuan13B_QLoRA🔥 | transformers | BaiChuan13B_QLoRA |
FinetuneLLM——BaiChuan13B_NER 🔥🔥🔥 | transformers | BaiChuan13B_NER |
FinetuneLLM——BaiChuan13B_MultiRounds 🔥 | transformers | BaiChuan13B_MultiRounds |
FinetuneLLM——Qwen7B_MultiRounds 🔥🔥🔥 | transformers | Qwen7B_MultiRounds |
FinetuneLLM——BaiChuan2_13B 🔥 | transformers | BaiChuan2_13B |
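
To give a feel for why a custom StepRunner is needed when the input type changes, here is a framework-free sketch of a step runner that accepts dictionary batches rather than (features, labels) pairs. Everything in it is hypothetical: the class name `DictBatchStepRunner`, its two-argument constructor, the `"labels"` key and the toy model are stand-ins chosen for this illustration, and the real StepRunner interface should be taken from the torchkeras source and the notebooks listed above.

```python
class DictBatchStepRunner:
    """Hypothetical step runner for batches that are dicts, not (features, labels) pairs."""

    def __init__(self, net, loss_fn):
        self.net = net
        self.loss_fn = loss_fn

    def __call__(self, batch):
        # Separate the labels from the model inputs; the key name "labels" is an assumption.
        labels = batch["labels"]
        inputs = {k: v for k, v in batch.items() if k != "labels"}
        # Pass the remaining entries to the model as keyword arguments.
        preds = self.net(**inputs)
        loss = self.loss_fn(preds, labels)
        # Return the per-step metrics as a dict.
        return {"loss": loss}


# Toy stand-ins so the sketch runs without any deep learning library.
def toy_net(x, scale):
    return [xi * scale for xi in x]


def toy_mse(preds, labels):
    return sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(labels)


runner = DictBatchStepRunner(toy_net, toy_mse)
step_metrics = runner({"x": [1.0, 2.0], "scale": 2.0, "labels": [2.0, 5.0]})
print(step_metrics)  # {'loss': 0.5}
```

The point of the pattern is that only the batch-unpacking code changes between model types, while the surrounding training loop stays the same.
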
If you have any other questions, you can contact the author through the WeChat official account below:
算法美食屋