diff --git a/README.md b/README.md
index c8f39b0..c31ed66 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,6 @@
# FATE-LLM
-The repo for federated fine-tuning for LLM.
+FATE-LLM is a framework for federated training of large language models. It also provides multiple parameter-efficient fine-tuning strategies for industrial applications.
+
+### Quick Start
+- [Federated ChatGLM-6B Training](./doc/tutorial/ChatGLM-6B.ipynb)
+- [GPT-2 Training](./doc/tutorial/GPT2-example.ipynb)
\ No newline at end of file
diff --git a/RELEASE.md b/RELEASE.md
new file mode 100644
index 0000000..2546f6b
--- /dev/null
+++ b/RELEASE.md
@@ -0,0 +1,4 @@
+## Release 1.1.0
+### Major Features and Improvements
+* Support federated training of ChatGLM-6B with parameter-efficient fine-tuning adapters such as Lora and P-Tuning V2.
+* Integration of `peft`, which supports many parameter-efficient adapters.
\ No newline at end of file
diff --git a/doc/tutorial/ChatGLM-6B.ipynb b/doc/tutorial/ChatGLM-6B.ipynb
new file mode 100644
index 0000000..a456566
--- /dev/null
+++ b/doc/tutorial/ChatGLM-6B.ipynb
@@ -0,0 +1,569 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Federated ChatGLM Tuning with Parameter Efficient methods in FATE-LLM"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, we will demonstrate how to efficiently train federated ChatGLM-6B using the FATE-LLM framework. In FATE-LLM, we introduce the \"pellm\"(Parameter Efficient Large Language Model) module, specifically designed for federated learning with large language models. We enable the implementation of parameter-efficient methods in federated learning, reducing communication overhead while maintaining model performance. In this tutorial we particularlly focus on ChatGLM-^b, and we will also emphasize the use of the Adapter mechanism for fine-tuning ChatGLM-6B, which enables us to effectively reduce communication volume and improve overall efficiency.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## ChatGLM-6B\n",
+ "\n",
+ "ChatGLM-6B is a large transformer-based language model with 6.2 billion parameters, trained on about 1T tokens of Chinese and English corpus. ChatGLM-6B is an open bilingual language model based on General Language Model. You can download the pretrained model from [here](https://huggingface.co/THUDM/chatglm-6b), or let the program automatically download it when you use it later."
+ ]
+ },
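+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Optionally, you can sanity-check the downloaded weights with a minimal sketch like the one below (assumptions: the weights are available at `chatglm_path`, `trust_remote_code=True` is needed because ChatGLM-6B ships custom modeling code, and a GPU is required for the half-precision chat call):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import AutoModel, AutoTokenizer\n",
+ "\n",
+ "chatglm_path = 'THUDM/chatglm-6b'  # or the local folder you downloaded the weights to\n",
+ "tokenizer = AutoTokenizer.from_pretrained(chatglm_path, trust_remote_code=True)\n",
+ "model = AutoModel.from_pretrained(chatglm_path, trust_remote_code=True).half().cuda()\n",
+ "model = model.eval()\n",
+ "response, history = model.chat(tokenizer, '你好', history=[])\n",
+ "print(response)"
+ ]
+ },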
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Dataset: Advertising Text Generation\n",
+ "\n",
+ "This is an advertising test generateion dataset, you can download dataset from the following links: \n",
+ "- [data link 1](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view)\n",
+ "- [data link 2](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1)\n",
+ "and place it in the examples/data folder. \n",
+ "\n",
+ "You can refer to following link for more details about [data](https://aclanthology.org/D19-1321.pdf)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "df = pd.read_json('${fate_install}/examples/data/AdvertiseGen/train.json', lines=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " content | \n",
+ " summary | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤 | \n",
+ " 宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果... | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖 | \n",
+ " 圆形领口修饰脖颈线条,适合各种脸型,耐看有气质。无袖设计,尤显清凉,简约横条纹装饰,使得整身... | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 类型#上衣*版型#宽松*颜色#粉红色*图案#字母*图案#文字*图案#线条*衣样式#卫衣*衣款... | \n",
+ " 宽松的卫衣版型包裹着整个身材,宽大的衣身与身材形成鲜明的对比描绘出纤瘦的身形。下摆与袖口的不... | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 类型#裙*版型#宽松*材质#雪纺*风格#清新*裙型#a字*裙长#连衣裙 | \n",
+ " 踩着轻盈的步伐享受在午后的和煦风中,让放松与惬意感为你免去一身的压力与束缚,仿佛要将灵魂也寄... | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 类型#上衣*材质#棉*颜色#蓝色*风格#潮*衣样式#polo*衣领型#polo领*衣袖长#短... | \n",
+ " 想要在人群中脱颖而出吗?那么最适合您的莫过于这款polo衫短袖,采用了经典的polo领口和柔... | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 114594 | \n",
+ " 类型#上衣*风格#运动*风格#休闲*衣样式#外套*衣领型#立领*衣袖长#长袖*衣门襟#拉链*... | \n",
+ " 基础的外套廓形,直筒,立领长袖,中间金属拉链穿脱,方便实用,带有浓浓的休闲运动味道。日常休闲... | \n",
+ "
\n",
+ " \n",
+ " 114595 | \n",
+ " 类型#上衣*风格#街头*图案#创意*衣样式#卫衣 | \n",
+ " 在这件卫衣上,BRAND-white集合了女性化的柔美还有不变的街头风采,<UNK><UNK... | \n",
+ "
\n",
+ " \n",
+ " 114596 | \n",
+ " 类型#裙*版型#宽松*版型#显瘦*颜色#黑色*图案#撞色*裙型#直筒裙*裙款式#拼接 | \n",
+ " 采用简洁大体的黑色格调,宽松舒适的裙子内里,配上落肩的袖子拼接,不惧夏日的炎热,穿出清凉舒适... | \n",
+ "
\n",
+ " \n",
+ " 114597 | \n",
+ " 类型#上衣*颜色#黑色*颜色#紫色*风格#性感*图案#字母*图案#文字*图案#线条*图案#刺... | \n",
+ " 卫衣的短款长度设计能够适当地露出腰线,打造出纤瘦的身材十分性感。衣身的字母刺绣图案有着小巧的... | \n",
+ "
\n",
+ " \n",
+ " 114598 | \n",
+ " 类型#上衣*颜色#黑白*风格#简约*风格#休闲*图案#条纹*衣样式#风衣*衣样式#外套 | \n",
+ " 设计师以条纹作为风衣外套的主要设计元素,以简约点缀了外套,带来大气休闲的视觉效果。因为采用的... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
114599 rows × 2 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " content \\\n",
+ "0 类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤 \n",
+ "1 类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖 \n",
+ "2 类型#上衣*版型#宽松*颜色#粉红色*图案#字母*图案#文字*图案#线条*衣样式#卫衣*衣款... \n",
+ "3 类型#裙*版型#宽松*材质#雪纺*风格#清新*裙型#a字*裙长#连衣裙 \n",
+ "4 类型#上衣*材质#棉*颜色#蓝色*风格#潮*衣样式#polo*衣领型#polo领*衣袖长#短... \n",
+ "... ... \n",
+ "114594 类型#上衣*风格#运动*风格#休闲*衣样式#外套*衣领型#立领*衣袖长#长袖*衣门襟#拉链*... \n",
+ "114595 类型#上衣*风格#街头*图案#创意*衣样式#卫衣 \n",
+ "114596 类型#裙*版型#宽松*版型#显瘦*颜色#黑色*图案#撞色*裙型#直筒裙*裙款式#拼接 \n",
+ "114597 类型#上衣*颜色#黑色*颜色#紫色*风格#性感*图案#字母*图案#文字*图案#线条*图案#刺... \n",
+ "114598 类型#上衣*颜色#黑白*风格#简约*风格#休闲*图案#条纹*衣样式#风衣*衣样式#外套 \n",
+ "\n",
+ " summary \n",
+ "0 宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果... \n",
+ "1 圆形领口修饰脖颈线条,适合各种脸型,耐看有气质。无袖设计,尤显清凉,简约横条纹装饰,使得整身... \n",
+ "2 宽松的卫衣版型包裹着整个身材,宽大的衣身与身材形成鲜明的对比描绘出纤瘦的身形。下摆与袖口的不... \n",
+ "3 踩着轻盈的步伐享受在午后的和煦风中,让放松与惬意感为你免去一身的压力与束缚,仿佛要将灵魂也寄... \n",
+ "4 想要在人群中脱颖而出吗?那么最适合您的莫过于这款polo衫短袖,采用了经典的polo领口和柔... \n",
+ "... ... \n",
+ "114594 基础的外套廓形,直筒,立领长袖,中间金属拉链穿脱,方便实用,带有浓浓的休闲运动味道。日常休闲... \n",
+ "114595 在这件卫衣上,BRAND-white集合了女性化的柔美还有不变的街头风采,=v1.11.2 and deploy it with gpu machines. To running this code, make sure training data path is already binded. The following code shoud be copy to a script and run in a command line like \"python federated_chatglm.py\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can use this script to submit the model, but submitting the model will take a long time to train and generate a long log, so we won't do it here."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch as t\n",
+ "import os\n",
+ "from pipeline import fate_torch_hook\n",
+ "from pipeline.component import HomoNN\n",
+ "from pipeline.backend.pipeline import PipeLine\n",
+ "from pipeline.component import Reader\n",
+ "from pipeline.interface import Data\n",
+ "from pipeline.runtime.entity import JobParameters\n",
+ "\n",
+ "fate_torch_hook(t)\n",
+ "\n",
+ "\n",
+ "guest_0 = 9999\n",
+ "host_1 = 10000\n",
+ "pipeline = PipeLine().set_initiator(role='guest', party_id=guest_0).set_roles(guest=guest_0, host=host_1,\n",
+ " arbiter=guest_0)\n",
+ "data_guest = {\"name\": \"ad_guest\", \"namespace\": \"experiment\"}\n",
+ "data_host = {\"name\": \"ad_host\", \"namespace\": \"experiment\"}\n",
+ "guest_data_path = \"${fate_install}/examples/data/AdvertiseGen/train.json_guest\"\n",
+ "host_data_path = \"${fate_install}/examples/data/AdvertiseGen/train.json_host\"\n",
+ "# make sure the guest and host's training data are already binded. beforem\n",
+ "\n",
+ "reader_0 = Reader(name=\"reader_0\")\n",
+ "reader_0.get_party_instance(role='guest', party_id=guest_0).component_param(table=data_guest)\n",
+ "reader_0.get_party_instance(role='host', party_id=host_1).component_param(table=data_host)\n",
+ "\n",
+ "## Add your pretriained model path here, will load model&tokenizer from this path\n",
+ "\n",
+ "from peft import LoraConfig, TaskType\n",
+ "lora_config = LoraConfig(\n",
+ " task_type=TaskType.CAUSAL_LM,\n",
+ " inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
+ " target_modules=['query_key_value'],\n",
+ ")\n",
+ "ds_config = {\n",
+ " \"train_micro_batch_size_per_gpu\": 1,\n",
+ " \"optimizer\": {\n",
+ " \"type\": \"Adam\",\n",
+ " \"params\": {\n",
+ " \"lr\": 5e-4\n",
+ " }\n",
+ " },\n",
+ " \"fp16\": {\n",
+ " \"enabled\": True\n",
+ " },\n",
+ " \"zero_optimization\": {\n",
+ " \"stage\": 2,\n",
+ " \"allgather_partitions\": True,\n",
+ " \"allgather_bucket_size\": 5e8,\n",
+ " \"overlap_comm\": False,\n",
+ " \"reduce_scatter\": True,\n",
+ " \"reduce_bucket_size\": 5e8,\n",
+ " \"contiguous_gradients\": True\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "model_path = \"your download chatglm path\"\n",
+ "from pipeline.component.homo_nn import DatasetParam, TrainerParam\n",
+ "model = t.nn.Sequential(\n",
+ " t.nn.CustModel(module_name='pellm.chatglm', class_name='ChatGLMForConditionalGeneration',\n",
+ " peft_config=lora_config.to_dict(), peft_type='LoraConfig',\n",
+ " pretrained_path=model_path)\n",
+ ")\n",
+ "\n",
+ "# DatasetParam\n",
+ "dataset_param = DatasetParam(dataset_name='glm_tokenizer', text_max_length=64, tokenizer_name_or_path=model_path,\n",
+ " padding_side=\"left\")\n",
+ "# TrainerParam\n",
+ "trainer_param = TrainerParam(trainer_name='fedavg_trainer', epochs=5, batch_size=4, checkpoint_save_freqs=1, pin_memory=False, task_type=\"seq_2_seq_lm\",\n",
+ " data_loader_worker=8, secure_aggregate=False, save_to_local_dir=True, # pay attention to tihs parameter\n",
+ " collate_fn=\"DataCollatorForSeq2Seq\")\n",
+ "\n",
+ "\n",
+ "nn_component = HomoNN(name='nn_0', model=model , ds_config=ds_config)\n",
+ "\n",
+ "# set parameter for client 1\n",
+ "nn_component.get_party_instance(role='guest', party_id=guest_0).component_param(\n",
+ " dataset=dataset_param,\n",
+ " trainer=trainer_param,\n",
+ " torch_seed=100\n",
+ ")\n",
+ "\n",
+ "# set parameter for client 2\n",
+ "nn_component.get_party_instance(role='host', party_id=host_1).component_param(\n",
+ " dataset=dataset_param,\n",
+ " trainer=trainer_param,\n",
+ " torch_seed=100\n",
+ ")\n",
+ "\n",
+ "# set parameter for server\n",
+ "nn_component.get_party_instance(role='arbiter', party_id=guest_0).component_param(\n",
+ " trainer=trainer_param\n",
+ ")\n",
+ "\n",
+ "pipeline.add_component(reader_0)\n",
+ "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))\n",
+ "pipeline.compile()\n",
+ "\n",
+ "pipeline.fit(JobParameters(task_conf={\n",
+ " \"nn_0\": {\n",
+ " \"launcher\": \"deepspeed\",\n",
+ " \"world_size\": 8 # world_size means num of gpus to train in a single client\n",
+ " }\n",
+ "}))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Training With P-Tuning V2 Adapter"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To use another adapter lke P-Tuning V2, slightly changes is needed!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pipeline.component.homo_nn import DatasetParam, TrainerParam\n",
+ "model = t.nn.Sequential(\n",
+ " t.nn.CustModel(module_name='pellm.chatglm', class_name='ChatGLMForConditionalGeneration',\n",
+ " pre_seq_len=128, # only this parameters is needed\n",
+ " pretrained_path=model_path)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Models trained with FATE-LLM can be find under the directory `${fate_install}/fateflow/model/$jobids/$cpn_name/{model.pkl, checkpoint_xxx.pkl/adapter_model.bin}`, users must may sure \"save_to_local_dir=True\". \n",
+ "The following code is an example to load trained lora adapter weights:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import sys\n",
+ "import torch\n",
+ "from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model\n",
+ "from transformers import AutoModel, AutoTokenizer\n",
+ "\n",
+ "\n",
+ "def load_model(pretrained_model_path):\n",
+ " _tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path, trust_remote_code=True)\n",
+ " _model = AutoModel.from_pretrained(pretrained_model_path, trust_remote_code=True)\n",
+ "\n",
+ " _model = _model.half()\n",
+ " _model = _model.eval()\n",
+ "\n",
+ " return _model, _tokenizer\n",
+ "\n",
+ "\n",
+ "def load_data(data_path):\n",
+ " with open(data_path, \"r\") as fin:\n",
+ " for _l in fin:\n",
+ " yield json.loads(_l.strip())\n",
+ "\n",
+ "chatglm_model_path = \"\"\n",
+ "model, tokenizer = load_model(chatglm_model_path)\n",
+ "\n",
+ "test_data_path = \"{fate_install}/examples/data/AdvertiseGen/dev.json\"\n",
+ "dataset = load_data(test_data_path)\n",
+ "\n",
+ "peft_path = trained_model_path\n",
+ "peft_config = LoraConfig(\n",
+ " task_type=TaskType.CAUSAL_LM,\n",
+ " inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
+ " target_modules=['query_key_value'],\n",
+ ")\n",
+ "\n",
+ "model = get_peft_model(model, peft_config)\n",
+ "model.load_state_dict(torch.load(peft_path), strict=False)\n",
+ "model = model.half()\n",
+ "model.eval()\n",
+ "\n",
+ "for p in model.parameters():\n",
+ " if p.requires_grad:\n",
+ " print(p)\n",
+ "\n",
+ "model.cuda(\"cuda:0\")\n",
+ "\n",
+ "content = \"advertisement keywords\"\n",
+ "model.chat(tokenizer, content, do_sample=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/doc/tutorial/GPT2-example.ipynb b/doc/tutorial/GPT2-example.ipynb
new file mode 100644
index 0000000..31cba2f
--- /dev/null
+++ b/doc/tutorial/GPT2-example.ipynb
@@ -0,0 +1,671 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Federated GPT-2 Tuning with Parameter Efficient methods in FATE-LLM"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, we will demonstrate how to efficiently train federated large language models using the FATE-LLM framework. In FATE-LLM, we introduce the \"pellm\"(Parameter Efficient Large Language Model) module, specifically designed for federated learning with large language models. We enable the implementation of parameter-efficient methods in federated learning, reducing communication overhead while maintaining model performance. In this tutorial we particularlly focus on GPT-2, and we will also emphasize the use of the Adapter mechanism for fine-tuning GPT-2, which enables us to effectively reduce communication volume and improve overall efficiency.\n",
+ "\n",
+ "By following this tutorial, you will learn how to leverage the FATE-LLM framework to rapidly fine-tune federated large language models, such as GPT-2, with ease and efficiency."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## GPT2\n",
+ "\n",
+ "GPT-2 is a large transformer-based language model with 1.5 billion parameters, trained on a dataset of 8 million web pages. GPT-2 is trained with a causal language modeling (CLM) objective, conditioning on a left-to-right context window of 1024 tokens. In this tutorial, we will use GPT2, you can download the pretrained model from [here](https://huggingface.co/gpt2) (We choose the smallest version for this tutorial), or let the program automatically download it when you use it later."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Dataset: IMDB Sentimental\n",
+ "\n",
+ "In this section, we will introduce the process of preparing the IMDB dataset for use in our federated learning task. We use our tokenizer dataset(based on HuggingFace tokenizer) to preprocess the text data.\n",
+ "\n",
+ "About IMDB Sentimental Dataset:\n",
+ "\n",
+ "This is an binary classification dataset, you can download our processed dataset from here: \n",
+ "- https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/IMDB.csv\n",
+ "and place it in the examples/data folder. \n",
+ "\n",
+ "The orgin data is from: \n",
+ "- https://ai.stanford.edu/~amaas/data/sentiment/\n",
+ "\n",
+ "### Check Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "df = pd.read_csv('../../../examples/data/IMDB.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " text | \n",
+ " label | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " One of the other reviewers has mentioned that ... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " A wonderful little production. <br /><br />The... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 2 | \n",
+ " I thought this was a wonderful way to spend ti... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 3 | \n",
+ " Basically there's a family where a little boy ... | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 4 | \n",
+ " Petter Mattei's \"Love in the Time of Money\" is... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 1996 | \n",
+ " 1996 | \n",
+ " THE CELL (2000) Rating: 8/10<br /><br />The Ce... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1997 | \n",
+ " 1997 | \n",
+ " This movie, despite its list of B, C, and D li... | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 1998 | \n",
+ " 1998 | \n",
+ " I loved this movie! It was all I could do not ... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1999 | \n",
+ " 1999 | \n",
+ " This was the worst movie I have ever seen Bill... | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 2000 | \n",
+ " 2000 | \n",
+ " Stranded in Space (1972) MST3K version - a ver... | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
2001 rows × 3 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id text label\n",
+ "0 0 One of the other reviewers has mentioned that ... 1\n",
+ "1 1 A wonderful little production.
The... 1\n",
+ "2 2 I thought this was a wonderful way to spend ti... 1\n",
+ "3 3 Basically there's a family where a little boy ... 0\n",
+ "4 4 Petter Mattei's \"Love in the Time of Money\" is... 1\n",
+ "... ... ... ...\n",
+ "1996 1996 THE CELL (2000) Rating: 8/10
The Ce... 1\n",
+ "1997 1997 This movie, despite its list of B, C, and D li... 0\n",
+ "1998 1998 I loved this movie! It was all I could do not ... 1\n",
+ "1999 1999 This was the worst movie I have ever seen Bill... 0\n",
+ "2000 2000 Stranded in Space (1972) MST3K version - a ver... 0\n",
+ "\n",
+ "[2001 rows x 3 columns]"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fate_llm.dataset.nlp_tokenizer import TokenizerDataset\n",
+ "\n",
+ "ds = TokenizerDataset(tokenizer_name_or_path=\"your model path\", text_max_length=128, \n",
+ " padding_side=\"left\", return_input_ids=False, pad_token='<|endoftext|>') # you can load tokenizer config from local pretrained tokenizer\n",
+ "\n",
+ "ds.load('../../../examples/data/IMDB.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "({'input_ids': tensor([ 3198, 286, 262, 584, 30702, 468, 4750, 326, 706, 4964,\n",
+ " 655, 352, 18024, 4471, 345, 1183, 307, 23373, 13, 1119,\n",
+ " 389, 826, 11, 355, 428, 318, 3446, 644, 3022, 351,\n",
+ " 502, 29847, 1671, 1220, 6927, 1671, 11037, 464, 717, 1517,\n",
+ " 326, 7425, 502, 546, 18024, 373, 663, 24557, 290, 42880,\n",
+ " 8589, 278, 8188, 286, 3685, 11, 543, 900, 287, 826,\n",
+ " 422, 262, 1573, 10351, 13, 9870, 502, 11, 428, 318,\n",
+ " 407, 257, 905, 329, 262, 18107, 2612, 276, 393, 44295,\n",
+ " 13, 770, 905, 16194, 645, 25495, 351, 13957, 284, 5010,\n",
+ " 11, 1714, 393, 3685, 13, 6363, 318, 22823, 11, 287,\n",
+ " 262, 6833, 779, 286, 262, 1573, 29847, 1671, 1220, 6927,\n",
+ " 1671, 11037, 1026, 318, 1444, 440, 57, 355, 326, 318,\n",
+ " 262, 21814, 1813, 284, 262, 34374, 22246, 4765]),\n",
+ " 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1])},\n",
+ " array([1.], dtype=float32))"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ds[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For more details of FATE-LLM dataset setting, we recommend that you read through these tutorials first: [NN Dataset Customization](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-your-Dataset.ipynbb), [Some Built-In Dataset](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Introduce-Built-In-Dataset.ipynb),"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## PELLM Model with Adapter\n",
+ "\n",
+ "In this section, we will guide you through the process of building a parameter-efficient language model using the FATE-LLM framework. We will focus on the implementation of the PELLM model and the integration of the Adapter mechanism, which enables efficient fine-tuning and reduces communication overhead in federated learning settings. Take GPT-2 as example you will learn how to leverage the FATE-LLM framework to rapidly develop and deploy a parameter-efficient language model using FATE-LLM built-in classes. Before starting this section, we recommend that you read through this tutorial first: [Model Customization](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-Model.ipynb)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### PELLM Models\n",
+ "\n",
+ "In this section we introduce the PELLM model, which is a parameter-efficient language model that can be used in federated learning settings. They are designed to be compatible with the FATE-LLM framework to enable federated model tuning/training.\n",
+ "\n",
+ "PELLM models are located at federatedml.nn.model_zoo.pellm(federatedml/nn/model_zoo/pellm):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "albert.py bert.py deberta.py gpt2.py\t\t\t __pycache__\r\n",
+ "bart.py chatglm.py distilbert.py parameter_efficient_llm.py roberta.py\r\n"
+ ]
+ }
+ ],
+ "source": [
+ "! ls ../../../fate/python/fate_llm/model_zoo/pellm"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can initialize your GPT2 model by loading the pretrained model from the model folder, or downloading the pretrained model from the Huggingface,\n",
+ "here we initialize the GPT2 model with the Lora Adapter, we will introduce Adapters in the following sub"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Adapters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can directly use adapters from the peft. See details for adapters on this page [Adapter Methods](https://huggingface.co/docs/peft/index) for more details. By specifying the adapter name and the adapter\n",
+ "config dict we can insert adapters into our language models:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from peft import LoraConfig, TaskType\n",
+ "\n",
+ "# define lora config\n",
+ "lora_config = LoraConfig(\n",
+ " task_type=TaskType.SEQ_CLS,\n",
+ " inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
+ " target_modules=['c_attn'],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Init PELLM Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fate_llm.model_zoo.pellm.gpt2 import GPT2\n",
+ "\n",
+ "# case 1 load pretrained weights from local pretrained weights, it is the same as using the huggingface pretrained model\n",
+ "path_to_pretrained_folder = 'your model path'\n",
+ "gpt2 = GPT2(pretrained_path=path_to_pretrained_folder, peft_type=\"LoraConfig\", peft_config=lora_config.to_dict(), num_labels=1, pad_token_id=50256)\n",
+ "\n",
+ "# case 2 directly download models from huggingface\n",
+ "# gpt2 = GPT2(pretrained_path=\"gpt2\", peft_type=\"LoraConfig\", peft_config=lora_config, num_labels=1, pad_token_id=50256)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this version we currently support these language model for federated training:\n",
+ "- ChatGLM\n",
+ "- Bert\n",
+ "- ALBert\n",
+ "- RoBerta\n",
+ "- GPT-2\n",
+ "- Bart\n",
+ "- DeBerta\n",
+ "- DistillBert"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**During the training process, all weights of the pretrained language model exclusive classifier head's weihgts will be frozen, and weights of adapters are traininable. Thus, FATE-LLM only train in the local training and aggregate adapters' weights and classifier head's weights(If has) in the fedederation process**\n",
+ "\n",
+ "Now available adapters are [Adapters Overview](https://huggingface.co/docs/peft/index) for details.\n"
+ ]
+ },
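+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make this concrete, the following small check (a sketch, assuming the `gpt2` PELLM instance created in the cell above) counts trainable versus total parameters; only the Lora adapter weights and the classification head should appear as trainable:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "total_params, trainable_params = 0, 0\n",
+ "for name, param in gpt2.named_parameters():\n",
+ "    total_params += param.numel()\n",
+ "    if param.requires_grad:\n",
+ "        trainable_params += param.numel()\n",
+ "        print(name, tuple(param.shape))\n",
+ "print(f'trainable params: {trainable_params} / total params: {total_params}')"
+ ]
+ },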
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Use PELLM Model in FATE with CustModel\n",
+ "\n",
+ "In this [Model Customization](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-Model.ipynb) tutorial, we demonstrate how to employ the t.nn.CustomModel class in fate_torch to parse a model's structure and submit it to a federated learning task. The CustomModel automatically imports the model class from the model_zoo and initializes the models with the parameters provided. Since these language models are built-in, we can directly use them in the CustomModel and easily add a classifier head to address the classification task at hand:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch as t\n",
+ "from pipeline import fate_torch_hook\n",
+ "from pipeline.component.nn import save_to_fate_llm\n",
+ "fate_torch_hook(t)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%save_to_fate_llm model sigmoid.py\n",
+ "\n",
+ "import torch as t\n",
+ "\n",
+ "class Sigmoid(t.nn.Module):\n",
+ " \n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.sigmoid = t.nn.Sigmoid()\n",
+ " \n",
+ " def forward(self, x):\n",
+ " return self.sigmoid(x.logits)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# build CustModel with PELLM, and add a classifier head\n",
+ "from transformers import GPT2Config\n",
+ "\n",
+ "checkpoint_path = \"your model path\"\n",
+ "model = t.nn.Sequential(\n",
+ " t.nn.CustModel(module_name='pellm.gpt2', class_name='GPT2', \n",
+ " pretrained_path=checkpoint_path, peft_config=lora_config.to_dict(), peft_type=\"LoraConfig\", num_labels=1, pad_token_id=50256),\n",
+ " t.nn.CustModel(module_name='sigmoid', class_name='Sigmoid')\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "Please note that during the training process, only trainable parameters will participate in the federated learning process."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Local Test\n",
+ "\n",
+ "Before submitting a federated learning task, we will demonstrate how to perform local testing to ensure the proper functionality of your custom dataset, model. We use the local mode of our FedAVGTrainer to test if our setting can run correctly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fate_llm.model_zoo.pellm.gpt2 import GPT2\n",
+ "from fate_llm.model_zoo.sigmoid import Sigmoid\n",
+ "from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer\n",
+ "from transformers import GPT2Config\n",
+ "from fate_llm.dataset.nlp_tokenizer import TokenizerDataset\n",
+ "\n",
+ "# load dataset\n",
+ "ds = TokenizerDataset(tokenizer_name_or_path=\"your model path\", text_max_length=128, \n",
+ " padding_side=\"left\", return_input_ids=False, pad_token='<|endoftext|>') # you can load tokenizer config from local pretrained tokenizer\n",
+ "\n",
+ "ds.load('../../../examples/data/IMDB.csv')\n",
+ "\n",
+ "checkpoint_path = \"your model path\"\n",
+ "model = t.nn.Sequential(\n",
+ " GPT2(pretrained_path=checkpoint_path, peft_config=lora_config.to_dict(), peft_type=\"LoraConfig\", num_labels=1, pad_token_id=50256),\n",
+ " Sigmoid()\n",
+ ")\n",
+ "\n",
+ "trainer = FedAVGTrainer(epochs=1, batch_size=8, shuffle=True, data_loader_worker=8)\n",
+ "trainer.local_mode()\n",
+ "trainer.set_model(model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "epoch is 0\n",
+ "100%|██████████| 251/251 [04:39<00:00, 1.11s/it]\n",
+ "epoch loss is 0.5148034488660345\n"
+ ]
+ }
+ ],
+ "source": [
+ "opt = t.optim.Adam(model.parameters(), lr=0.001)\n",
+ "loss = t.nn.BCELoss()\n",
+ "# local test, here we only use CPU for training\n",
+ "trainer.train(ds, None, opt, loss)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Submit Federated Task\n",
+ "Once you have successfully completed local testing, We can submit a task to FATE. Please notice that this tutorial is ran on a standalone version. **Please notice that in this tutorial we are using a standalone version, if you are using a cluster version, you need to bind the data with the corresponding name&namespace on each machine.**\n",
+ "\n",
+ "In this example we load pretrained weights for gpt2 model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch as t\n",
+ "import os\n",
+ "from pipeline import fate_torch_hook\n",
+ "from pipeline.component import HomoNN\n",
+ "from pipeline.backend.pipeline import PipeLine\n",
+ "from pipeline.component import Reader\n",
+ "from pipeline.interface import Data\n",
+ "from transformers import GPT2Config\n",
+ "\n",
+ "fate_torch_hook(t)\n",
+ "\n",
+ "\n",
+ "fate_project_path = \"your model path\"\n",
+ "guest_0 = 9999\n",
+ "host_1 = 9999\n",
+ "pipeline = PipeLine().set_initiator(role='guest', party_id=guest_0).set_roles(guest=guest_0, host=host_1,\n",
+ " arbiter=guest_0)\n",
+ "data_0 = {\"name\": \"imdb\", \"namespace\": \"experiment\"}\n",
+ "data_path = fate_project_path + '/examples/data/IMDB.csv'\n",
+ "pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)\n",
+ "pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)\n",
+ "reader_0 = Reader(name=\"reader_0\")\n",
+ "reader_0.get_party_instance(role='guest', party_id=guest_0).component_param(table=data_0)\n",
+ "reader_0.get_party_instance(role='host', party_id=host_1).component_param(table=data_0)\n",
+ "\n",
+ "reader_1 = Reader(name=\"reader_1\")\n",
+ "reader_1.get_party_instance(role='guest', party_id=guest_0).component_param(table=data_0)\n",
+ "reader_1.get_party_instance(role='host', party_id=host_1).component_param(table=data_0)\n",
+ "\n",
+ "\n",
+ "## Add your pretriained model path here, will load model&tokenizer from this path\n",
+ "\n",
+ "\n",
+ "## LoraConfig\n",
+ "from peft import LoraConfig, TaskType\n",
+ "lora_config = LoraConfig(\n",
+ " task_type=TaskType.SEQ_CLS,\n",
+ " inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
+ " target_modules=['c_attn']\n",
+ ")\n",
+ "\n",
+ "\n",
+ "model_path = 'your model path'\n",
+ "model = t.nn.Sequential(\n",
+ " t.nn.CustModel(module_name='pellm.gpt2', class_name='GPT2', pretrained_path=model_path,\n",
+ " peft_config=lora_config.to_dict(), peft_type=\"LoraConfig\", num_labels=1, pad_token_id=50256),\n",
+ " t.nn.CustModel(module_name='sigmoid', class_name='Sigmoid')\n",
+ ")\n",
+ "\n",
+ "# DatasetParam\n",
+ "dataset_param = DatasetParam(dataset_name='nlp_tokenizer',text_max_length=128, tokenizer_name_or_path=model_path, \n",
+ " padding_side=\"left\", return_input_ids=False, pad_token='<|endoftext|>')\n",
+ "# TrainerParam\n",
+ "trainer_param = TrainerParam(trainer_name='fedavg_trainer', epochs=1, batch_size=8,\n",
+ " data_loader_worker=8, secure_aggregate=True)\n",
+ "\n",
+ "\n",
+ "nn_component = HomoNN(name='nn_0', model=model)\n",
+ "\n",
+ "# set parameter for client 1\n",
+ "nn_component.get_party_instance(role='guest', party_id=guest_0).component_param(\n",
+ " loss=t.nn.BCELoss(),\n",
+ " optimizer = t.optim.Adam(lr=0.0001, eps=1e-8),\n",
+ " dataset=dataset_param, \n",
+ " trainer=trainer_param,\n",
+ " torch_seed=100 \n",
+ ")\n",
+ "\n",
+ "# set parameter for client 2\n",
+ "nn_component.get_party_instance(role='host', party_id=host_1).component_param(\n",
+ " loss=t.nn.BCELoss(),\n",
+ " optimizer = t.optim.Adam(lr=0.0001, eps=1e-8),\n",
+ " dataset=dataset_param, \n",
+ " trainer=trainer_param,\n",
+ " torch_seed=100 \n",
+ ")\n",
+ "\n",
+ "# set parameter for server\n",
+ "nn_component.get_party_instance(role='arbiter', party_id=guest_0).component_param( \n",
+ " trainer=trainer_param\n",
+ ")\n",
+ "\n",
+ "pipeline.add_component(reader_0)\n",
+ "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))\n",
+ "pipeline.compile()\n",
+ "\n",
+ "pipeline.fit()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can use this script to submit the model, but submitting the model will take a long time to train and generate a long log, so we won't do it here."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training with CUDA\n",
+ "\n",
+ "You can use GPU by setting the cuda parameter of the FedAVGTrainer:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer_param = TrainerParam(trainer_name='fedavg_trainer', epochs=1, batch_size=8, \n",
+ " data_loader_worker=8, secure_aggregate=True, cuda=0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The cuda parameter here accepts an integer value that corresponds to the index of the GPU you want to use for training. \n",
+ "In the example above, the value is set to 0, which means that on every client the first available GPU in the system will be used. \n",
+ "If you have multiple GPUs and would like to use a specific one, simply change the value of the cuda parameter to the appropriate index."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this example, client_0 is set to use GPUs with indices [0, 1, 2, 3], while client_1 uses GPUs with indices [0, 3, 4]. The server does not support GPUs usage in the aggregation procession"
+ ]
+ },
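+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The sketch below shows such a per-client setting (assumptions: the `cuda` parameter also accepts a list of GPU indices, as described above, and the indices are illustrative); it simply overrides the trainer parameter per party:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# per-party trainer parameters: each client picks its own GPU indices (illustrative values)\n",
+ "trainer_param_guest = TrainerParam(trainer_name='fedavg_trainer', epochs=1, batch_size=8,\n",
+ "                                   data_loader_worker=8, secure_aggregate=True, cuda=[0, 1, 2, 3])\n",
+ "trainer_param_host = TrainerParam(trainer_name='fedavg_trainer', epochs=1, batch_size=8,\n",
+ "                                  data_loader_worker=8, secure_aggregate=True, cuda=[0, 3, 4])\n",
+ "\n",
+ "nn_component.get_party_instance(role='guest', party_id=guest_0).component_param(trainer=trainer_param_guest)\n",
+ "nn_component.get_party_instance(role='host', party_id=host_1).component_param(trainer=trainer_param_host)"
+ ]
+ },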
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/python/fate_llm/__init__.py b/python/fate_llm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/python/fate_llm/dataset/__init__.py b/python/fate_llm/dataset/__init__.py
new file mode 100644
index 0000000..ef471ba
--- /dev/null
+++ b/python/fate_llm/dataset/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
\ No newline at end of file
diff --git a/python/fate_llm/dataset/glm_tokenizer.py b/python/fate_llm/dataset/glm_tokenizer.py
new file mode 100644
index 0000000..17970f7
--- /dev/null
+++ b/python/fate_llm/dataset/glm_tokenizer.py
@@ -0,0 +1,88 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from federatedml.nn.dataset.base import Dataset
+import pandas as pd
+from transformers import AutoTokenizer
+
+
+PROMPT_TEMPLATE = "{prompt}"
+
+
+class GLMTokenizerDataset(Dataset):
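+    """
+    Dataset for ChatGLM-style seq2seq fine-tuning. Each line of the input json file must
+    contain a prompt column (default 'content') and a response column (default 'summary').
+    Both are tokenized with the model's AutoTokenizer, joined via
+    build_inputs_with_special_tokens, and the prompt part of the labels is masked with -100
+    so that the loss is only computed on the response tokens.
+    """
+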
+ def __init__(self, truncation=True, text_max_length=256,
+ tokenizer_name_or_path=None,
+ padding=True, padding_side="right", pad_token=None,
+ trust_remote_code=True,
+ prompt_template=None,
+ prompt_column="content",
+ response_column="summary"
+ ):
+
+ super(GLMTokenizerDataset, self).__init__()
+ self.label = None
+ self.tokenizer = None
+ self.padding = padding
+ self.truncation = truncation
+ self.max_length = text_max_length
+ self.tokenizer_name_or_path = tokenizer_name_or_path
+ self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=trust_remote_code)
+ self.tokenizer.padding_side = padding_side
+ if pad_token is not None:
+ self.tokenizer.add_special_tokens({'pad_token': pad_token})
+
+ self.prompt_template = prompt_template if prompt_template else PROMPT_TEMPLATE
+ self.prompt_column = prompt_column
+ self.response_column = response_column
+ self._data = None
+
+ def load(self, file_path):
+ df = pd.read_json(file_path, lines=True)
+ self._data = df.apply(self._process_data, axis=1)
+
+ def _process_data(self, line):
+ _prompt = line[self.prompt_column]
+ _response = line[self.response_column]
+
+ prompt = self.prompt_template.format_map(dict(prompt=_prompt))
+ prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
+ target_ids = self.tokenizer.encode(_response, add_special_tokens=False)
+
+ if len(prompt_ids) > self.max_length - 1:
+ prompt_ids = prompt_ids[: self.max_length - 1]
+ if len(target_ids) > self.max_length - 2:
+ target_ids = target_ids[: self.max_length - 2]
+
+ input_ids = self.tokenizer.build_inputs_with_special_tokens(prompt_ids, target_ids)
+
+ seq_length = input_ids.index(self.tokenizer.bos_token_id)
+ labels = [-100] * seq_length + input_ids[seq_length:]
+
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ }
+
+ def get_vocab_size(self):
+ return self.tokenizer.vocab_size
+
+ def __getitem__(self, item):
+ return self._data[item]
+
+ def __len__(self):
+ return len(self._data)
+
+ def __repr__(self):
+ return self.tokenizer.__repr__()
diff --git a/python/fate_llm/dataset/nlp_tokenizer.py b/python/fate_llm/dataset/nlp_tokenizer.py
new file mode 100644
index 0000000..79c81bb
--- /dev/null
+++ b/python/fate_llm/dataset/nlp_tokenizer.py
@@ -0,0 +1,116 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from federatedml.nn.dataset.base import Dataset
+import pandas as pd
+import torch as t
+from transformers import AutoTokenizer
+import os
+import numpy as np
+
+# avoid tokenizer parallelism
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class TokenizerDataset(Dataset):
+ """
+ A Dataset for some basic NLP Tasks, this dataset will automatically transform raw text into word indices
+ using AutoTokenizer from transformers library,
+
+ Parameters
+ ----------
+ truncation bool, truncate word sequence to 'text_max_length'
+ text_max_length int, max length of word sequences
+ tokenizer_name_or_path str, name of bert tokenizer(see transformers official for details) or path to local
+ transformer tokenizer folder
+ return_label bool, return label or not, this option is for host dataset, when running hetero-NN
+ padding bool, whether to pad the word sequence to 'text_max_length'
+ padding_side str, 'left' or 'right', where to pad the word sequence
+ pad_token str, pad token, use this str as pad token, if None, use tokenizer.pad_token
+ return_input_ids bool, if True, __getitem__ returns only word_idx['input_ids']; if False, it returns the full tokenizer output dict (input_ids, attention_mask, ...)
+ """
+
+ def __init__(self, truncation=True, text_max_length=128,
+ tokenizer_name_or_path="bert-base-uncased",
+ return_label=True, padding=True, padding_side="right", pad_token=None,
+ return_input_ids=True
+ ):
+
+ super(TokenizerDataset, self).__init__()
+ self.text = None
+ self.word_idx = None
+ self.label = None
+ self.tokenizer = None
+ self.sample_ids = None
+ self.padding = padding
+ self.truncation = truncation
+ self.max_length = text_max_length
+ self.with_label = return_label
+ self.tokenizer_name_or_path = tokenizer_name_or_path
+ self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path)
+ self.tokenizer.padding_side = padding_side
+ self.return_input_ids = return_input_ids
+ if pad_token is not None:
+ self.tokenizer.add_special_tokens({'pad_token': pad_token})
+
+ def load(self, file_path):
+
+ tokenizer = self.tokenizer
+ self.text = pd.read_csv(file_path)
+ text_list = list(self.text.text)
+
+ self.word_idx = tokenizer(
+ text_list,
+ padding=self.padding,
+ return_tensors='pt',
+ truncation=self.truncation,
+ max_length=self.max_length)
+
+ if self.return_input_ids:
+ self.word_idx = self.word_idx['input_ids']
+
+ if self.with_label:
+ self.label = t.Tensor(self.text.label).detach().numpy()
+ self.label = self.label.reshape((len(self.text), -1))
+
+ if 'id' in self.text:
+ self.sample_ids = self.text['id'].values.tolist()
+
+ def get_classes(self):
+ return np.unique(self.label).tolist()
+
+ def get_vocab_size(self):
+ return self.tokenizer.vocab_size
+
+ def get_sample_ids(self):
+ return self.sample_ids
+
+ def __getitem__(self, item):
+
+ if self.return_input_ids:
+ ret = self.word_idx[item]
+ else:
+ ret = {k: v[item] for k, v in self.word_idx.items()}
+
+ if self.with_label:
+ return ret, self.label[item]
+
+ return ret
+
+ def __len__(self):
+ return len(self.text)
+
+ def __repr__(self):
+ return self.tokenizer.__repr__()
diff --git a/python/fate_llm/model_zoo/__init__.py b/python/fate_llm/model_zoo/__init__.py
new file mode 100644
index 0000000..ef471ba
--- /dev/null
+++ b/python/fate_llm/model_zoo/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
\ No newline at end of file
diff --git a/python/fate_llm/model_zoo/pellm/albert.py b/python/fate_llm/model_zoo/pellm/albert.py
new file mode 100644
index 0000000..24a9ddd
--- /dev/null
+++ b/python/fate_llm/model_zoo/pellm/albert.py
@@ -0,0 +1,44 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from transformers import AlbertConfig, AutoConfig
+from transformers import AlbertForSequenceClassification
+from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
+
+
+class Albert(PELLM):
+
+ config_class = AlbertConfig
+ model_loader = AlbertForSequenceClassification
+
+ def __init__(self, config: dict = None,
+ pretrained_path: str = None,
+ peft_type: str = None,
+ peft_config: dict = None,
+ **kwargs
+ ) -> None:
+
+ if pretrained_path is not None:
+ self.check_config(pretrain_path=pretrained_path)
+ if config is None and pretrained_path is None:
+ config = AlbertConfig().to_dict() # use default model setting
+ super().__init__(config=config, pretrained_path=pretrained_path,
+ peft_type=peft_type, peft_config=peft_config, **kwargs)
+
+ def check_config(self, pretrain_path):
+ config = AutoConfig.from_pretrained(pretrain_path)
+ assert isinstance(
+ config, AlbertConfig), 'The config of pretrained model must be AlbertConfig, but got {}'.format(
+ type(config))
diff --git a/python/fate_llm/model_zoo/pellm/bart.py b/python/fate_llm/model_zoo/pellm/bart.py
new file mode 100644
index 0000000..d401bfb
--- /dev/null
+++ b/python/fate_llm/model_zoo/pellm/bart.py
@@ -0,0 +1,42 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from transformers import BartConfig, AutoConfig
+from transformers import BartForSequenceClassification
+from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
+
+
+class Bart(PELLM):
+ config_class = BartConfig
+ model_loader = BartForSequenceClassification
+
+ def __init__(self, config: dict = None,
+ pretrained_path: str = None,
+ peft_type: str = None,
+ peft_config: dict = None,
+ **kwargs) -> None:
+
+ if pretrained_path is not None:
+ self.check_config(pretrain_path=pretrained_path)
+ if config is None and pretrained_path is None:
+ config = BartConfig().to_dict()
+ super().__init__(config=config, pretrained_path=pretrained_path,
+ peft_type=peft_type, peft_config=peft_config, **kwargs)
+
+ def check_config(self, pretrain_path):
+ config = AutoConfig.from_pretrained(pretrain_path)
+ assert isinstance(
+ config, BartConfig), 'The config of pretrained model must be BartConfig, but got {}'.format(
+ type(config))
diff --git a/python/fate_llm/model_zoo/pellm/bert.py b/python/fate_llm/model_zoo/pellm/bert.py
new file mode 100644
index 0000000..c95bc92
--- /dev/null
+++ b/python/fate_llm/model_zoo/pellm/bert.py
@@ -0,0 +1,42 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from transformers import BertConfig, AutoConfig
+from transformers import BertForSequenceClassification
+from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
+
+
+class Bert(PELLM):
+ config_class = BertConfig
+ model_loader = BertForSequenceClassification
+
+ def __init__(self, config: dict = None,
+ pretrained_path: str = None,
+ peft_type: str = None,
+ peft_config: dict = None,
+ **kwargs) -> None:
+
+ if pretrained_path is not None:
+ self.check_config(pretrain_path=pretrained_path)
+ if config is None and pretrained_path is None:
+ config = BertConfig().to_dict()
+ super().__init__(config=config, pretrained_path=pretrained_path,
+ peft_type=peft_type, peft_config=peft_config, **kwargs)
+
+ def check_config(self, pretrain_path):
+ config = AutoConfig.from_pretrained(pretrain_path)
+ assert isinstance(
+ config, BertConfig), 'The config of pretrained model must be BertConfig, but got {}'.format(
+ type(config))
diff --git a/python/fate_llm/model_zoo/pellm/chatglm.py b/python/fate_llm/model_zoo/pellm/chatglm.py
new file mode 100644
index 0000000..8648cf4
--- /dev/null
+++ b/python/fate_llm/model_zoo/pellm/chatglm.py
@@ -0,0 +1,54 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
+from transformers import AutoConfig
+
+
+class ChatGLMForConditionalGeneration(PELLM):
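+    """
+    ChatGLM-6B wrapped as a PELLM model. Pass peft_type/peft_config (e.g. a Lora config dict)
+    to attach a peft adapter, or set pre_seq_len to use ChatGLM's built-in P-Tuning V2 prefix
+    encoder instead; with fp16=True the frozen backbone is kept in half precision.
+    """
+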
+ enable_save_pretrained = True
+
+ def __init__(self,
+ pretrained_path: str = None,
+ peft_type: str = None,
+ peft_config: dict = None,
+ fp16: bool = True,
+ pre_seq_len: int = None,
+ prefix_projection: bool = False) -> None:
+
+ self.pre_seq_len = pre_seq_len
+ self.prefix_projection = prefix_projection
+ self.fp16 = fp16
+
+ super().__init__(pretrained_path=pretrained_path,
+ peft_type=peft_type,
+ peft_config=peft_config)
+
+ def init_config(self):
+ self.config = AutoConfig.from_pretrained(self.config_path, trust_remote_code=True)
+ self.config.pre_seq_len = self.pre_seq_len
+ self.config.prefix_projection = self.prefix_projection
+
+ def init_base_lm(self):
+ super(ChatGLMForConditionalGeneration, self).init_base_lm(trust_remote_code=True)
+ if self.fp16:
+ self._pe_lm.half()
+
+ def add_peft(self):
+ if self.pre_seq_len:
+ self._pe_lm.half()
+ self._pe_lm.transformer.prefix_encoder.float()
+ else:
+ super(ChatGLMForConditionalGeneration, self).add_peft()
diff --git a/python/fate_llm/model_zoo/pellm/deberta.py b/python/fate_llm/model_zoo/pellm/deberta.py
new file mode 100644
index 0000000..376dcb2
--- /dev/null
+++ b/python/fate_llm/model_zoo/pellm/deberta.py
@@ -0,0 +1,43 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from transformers import DebertaConfig, AutoConfig
+from transformers import DebertaForSequenceClassification
+from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
+
+
+class Deberta(PELLM):
+
+ config_class = DebertaConfig
+ model_loader = DebertaForSequenceClassification
+
+ def __init__(self, config: dict = None,
+ pretrained_path: str = None,
+ peft_type: str = None,
+ peft_config: dict = None,
+ **kwargs) -> None:
+
+ if pretrained_path is not None:
+ self.check_config(pretrain_path=pretrained_path)
+ if config is None and pretrained_path is None:
+ config = DebertaConfig().to_dict()
+ super().__init__(config=config, pretrained_path=pretrained_path,
+ peft_type=peft_type, peft_config=peft_config, **kwargs)
+
+ def check_config(self, pretrain_path):
+ config = AutoConfig.from_pretrained(pretrain_path)
+ assert isinstance(
+ config, DebertaConfig), 'The config of pretrained model must be DebertaConfig, but got {}'.format(
+ type(config))
diff --git a/python/fate_llm/model_zoo/pellm/distilbert.py b/python/fate_llm/model_zoo/pellm/distilbert.py
new file mode 100644
index 0000000..c23e62f
--- /dev/null
+++ b/python/fate_llm/model_zoo/pellm/distilbert.py
@@ -0,0 +1,42 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from transformers import DistilBertConfig, AutoConfig
+from transformers import DistilBertForSequenceClassification
+from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
+
+
+class DistilBert(PELLM):
+ config_class = DistilBertConfig
+ model_loader = DistilBertForSequenceClassification
+
+ def __init__(self, config: dict = None,
+ pretrained_path: str = None,
+ peft_type: str = None,
+ peft_config: dict = None,
+ **kwargs) -> None:
+
+ if pretrained_path is not None:
+ self.check_config(pretrain_path=pretrained_path)
+ if config is None and pretrained_path is None:
+ config = DistilBertConfig().to_dict()
+ super().__init__(config=config, pretrained_path=pretrained_path,
+ peft_type=peft_type, peft_config=peft_config, **kwargs)
+
+ def check_config(self, pretrain_path):
+ config = AutoConfig.from_pretrained(pretrain_path)
+ assert isinstance(
+ config, DistilBertConfig), 'The config of pretrained model must be DistilBertConfig, but got {}'.format(
+ type(config))
diff --git a/python/fate_llm/model_zoo/pellm/gpt2.py b/python/fate_llm/model_zoo/pellm/gpt2.py
new file mode 100644
index 0000000..dcfa036
--- /dev/null
+++ b/python/fate_llm/model_zoo/pellm/gpt2.py
@@ -0,0 +1,42 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from transformers import GPT2Config, AutoConfig
+from transformers import GPT2ForSequenceClassification
+from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
+
+
+class GPT2(PELLM):
+ config_class = GPT2Config
+ model_loader = GPT2ForSequenceClassification
+
+ def __init__(self, config: dict = None,
+ pretrained_path: str = None,
+ peft_type: str = None,
+ peft_config: dict = None,
+ **kwargs) -> None:
+
+ if pretrained_path is not None:
+ self.check_config(pretrain_path=pretrained_path)
+ if config is None and pretrained_path is None:
+ config = GPT2Config().to_dict()
+ super().__init__(config=config, pretrained_path=pretrained_path,
+ peft_type=peft_type, peft_config=peft_config, **kwargs)
+
+ def check_config(self, pretrain_path):
+ config = AutoConfig.from_pretrained(pretrain_path)
+ assert isinstance(
+ config, GPT2Config), 'The config of pretrained model must be GPT2Config, but got {}'.format(
+ type(config))
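
`peft_type` is matched against `AVAILABLE_PEFT_CONFIG` in the base class, i.e. against every attribute of `peft` whose name ends in `Config`, so the string must be a class name such as `LoraConfig` or `PrefixTuningConfig`. Below is a hedged GPT-2 sketch; the `gpt2` checkpoint id, LoRA settings and pad-token choice are assumptions.

```python
# Minimal sketch; "gpt2" checkpoint id and hyperparameters are assumptions.
import peft
from fate_llm.model_zoo.pellm.gpt2 import GPT2

# Valid peft_type values are the *Config class names exported by peft.
print([name for name in dir(peft) if name.endswith("Config")])

model = GPT2(
    pretrained_path="gpt2",
    peft_type="LoraConfig",
    peft_config={"task_type": "SEQ_CLS", "r": 8, "lora_alpha": 32, "lora_dropout": 0.1},
    num_labels=2,         # forwarded into GPT2Config via config.update
    pad_token_id=50256,   # GPT-2 has no pad token; classification needs one (assumed choice)
)
```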
diff --git a/python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py b/python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py
new file mode 100644
index 0000000..a120ce2
--- /dev/null
+++ b/python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py
@@ -0,0 +1,135 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import peft
+import torch
+from peft import PeftModel
+from transformers import AutoConfig
+from transformers import AutoModel
+from transformers.configuration_utils import PretrainedConfig
+from federatedml.util import LOGGER
+
+
+AVAILABLE_PEFT_CONFIG = list(
+ filter(
+ lambda peft_type: peft_type.endswith("Config"), dir(peft)
+ )
+)
+
+
+class PELLM(torch.nn.Module):
+
+ config_class: PretrainedConfig = None
+ enable_save_pretrained: bool = True
+ model_loader = None
+
+ def __init__(self, config: dict = None,
+ pretrained_path: str = None,
+ peft_type: str = None,
+ peft_config: dict = None,
+ **kwargs
+ ) -> None:
+
+ super().__init__()
+ self._pe_lm: PeftModel = None
+ self.config = config
+ self.config_path = pretrained_path
+ self.peft_type = peft_type
+ self.peft_config = peft_config
+
+ assert self.config_path is not None or self.config is not None, \
+ "At least one of config_path and config must be set."
+ self._init_pelm(**kwargs)
+
+ def _init_pelm(self, **kwargs):
+ self.init_lm_with_peft(**kwargs)
+ self.model_summary()
+
+ def init_lm_with_peft(self, **kwargs):
+ self.init_config(**kwargs)
+ self.init_base_lm()
+ self.add_peft()
+
+ def init_config(self, **kwargs):
+ if self.config_path is not None:
+ self.config = AutoConfig.from_pretrained(self.config_path)
+ elif self.config is not None and self.config_class is not None:
+ self.config = self.config_class().from_dict(self.config)
+ else:
+ raise ValueError(
+ 'config_path to pretrained model folder and model config dict cannot be None at the same time, '
+ 'you need to specify one of them')
+
+ if kwargs:
+ self.config.update(kwargs)
+
+ def init_base_lm(self, **kwargs):
+ model_loader = self.model_loader if self.model_loader is not None else AutoModel
+ if self.config is not None:
+ self._pe_lm = model_loader.from_pretrained(self.config_path, config=self.config, **kwargs)
+ elif self.config_path is not None:
+ self._pe_lm = model_loader.from_pretrained(self.config_path, **kwargs)
+ else:
+ raise ValueError(
+ 'config_path to pretrained model folder cannot be None')
+
+ def add_peft(self):
+        assert self.peft_type in AVAILABLE_PEFT_CONFIG, 'peft type {} not in available configs {}'.format(
+ self.peft_type, AVAILABLE_PEFT_CONFIG)
+
+ if self.peft_config is None:
+ peft_config = getattr(peft, self.peft_type)()
+ else:
+ peft_config = getattr(peft, self.peft_type)(**self.peft_config)
+
+ self._pe_lm = peft.get_peft_model(self._pe_lm, peft_config)
+
+ def model_summary(self):
+ try:
+ summary = self._pe_lm.print_trainable_parameters()
+
+ LOGGER.debug('PELLM model summary: \n{}'.format(summary))
+ except BaseException:
+ pass
+
+ def _get_trainable_parameters(self):
+ trainable = []
+ for n, p in self._pe_lm.named_parameters():
+ if p.requires_grad:
+ trainable.append(p)
+ return trainable
+
+ def forward(self, tokenized_data: dict):
+ return self._pe_lm(**tokenized_data)
+
+ def save_pretrained(self, path):
+ if not self.enable_save_pretrained:
+ raise ValueError("To save trainable parameters only, set enable_save_pretrained=True in your model")
+
+ from pathlib import Path
+
+ state_dict = {
+ k: p.to("cpu") for k, p in self._pe_lm.named_parameters() if p.requires_grad
+ }
+        Path(path).mkdir(exist_ok=True)
+ torch.save(state_dict, Path(path).joinpath("adapter_model.bin"))
+
+
+class AutoPELLM(PELLM):
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+
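
`save_pretrained` above intentionally writes only the parameters that `peft` leaves trainable, i.e. the adapter weights, which is what federated aggregation exchanges instead of the full backbone. A hedged round-trip sketch, assuming the GPT-2 setup from the previous example and a writable `./adapter_ckpt` directory:

```python
# Illustrative round trip; checkpoint id, LoRA settings and paths are assumptions.
import torch
from fate_llm.model_zoo.pellm.gpt2 import GPT2

kwargs = dict(pretrained_path="gpt2", peft_type="LoraConfig",
              peft_config={"task_type": "SEQ_CLS", "r": 8},
              num_labels=2, pad_token_id=50256)

model = GPT2(**kwargs)
model.save_pretrained("./adapter_ckpt")   # writes ./adapter_ckpt/adapter_model.bin (adapters only)

# Rebuild the same architecture and load the adapter weights back in.
# strict=False because the frozen backbone weights are absent from the checkpoint.
restored = GPT2(**kwargs)
restored._pe_lm.load_state_dict(
    torch.load("./adapter_ckpt/adapter_model.bin"), strict=False)
```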
diff --git a/python/fate_llm/model_zoo/pellm/roberta.py b/python/fate_llm/model_zoo/pellm/roberta.py
new file mode 100644
index 0000000..33d1079
--- /dev/null
+++ b/python/fate_llm/model_zoo/pellm/roberta.py
@@ -0,0 +1,42 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from transformers import RobertaConfig, AutoConfig
+from transformers import RobertaForSequenceClassification
+from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
+
+
+class Roberta(PELLM):
+ config_class = RobertaConfig
+ model_loader = RobertaForSequenceClassification
+
+ def __init__(self, config: dict = None,
+ pretrained_path: str = None,
+ peft_type: str = None,
+ peft_config: dict = None,
+ **kwargs) -> None:
+
+ if pretrained_path is not None:
+ self.check_config(pretrain_path=pretrained_path)
+ if config is None and pretrained_path is None:
+ config = RobertaConfig().to_dict()
+ super().__init__(config=config, pretrained_path=pretrained_path,
+ peft_type=peft_type, peft_config=peft_config, **kwargs)
+
+ def check_config(self, pretrain_path):
+ config = AutoConfig.from_pretrained(pretrain_path)
+ assert isinstance(
+ config, RobertaConfig), 'The config of pretrained model must be RobertaConfig, but got {}'.format(
+ type(config))
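
For completeness, the task-specific `Roberta` wrapper above can be contrasted with the generic `AutoPELLM` from `parameter_efficient_llm.py`: the latter falls back to `transformers.AutoModel`, so it loads the bare backbone and skips the config-class check. The checkpoint ids and LoRA settings below are illustrative assumptions.

```python
# Illustrative only; checkpoint ids and LoRA settings are assumptions.
from fate_llm.model_zoo.pellm.roberta import Roberta
from fate_llm.model_zoo.pellm.parameter_efficient_llm import AutoPELLM

# Task-specific wrapper: sequence-classification head plus the RobertaConfig check.
clf = Roberta(
    pretrained_path="roberta-base",
    peft_type="LoraConfig",
    peft_config={"task_type": "SEQ_CLS", "r": 8, "target_modules": ["query", "value"]},
    num_labels=2,
)

# Generic wrapper: bare AutoModel backbone, no task head, no config-class check.
backbone = AutoPELLM(
    pretrained_path="roberta-base",
    peft_type="LoraConfig",
    peft_config={"r": 8, "target_modules": ["query", "value"]},
)
```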