-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
886d26d
commit 3b047f7
Showing
3 changed files
with
348 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,346 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4d252bab-f488-489f-a3a3-a8869cf99dcb", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"1. 安装最新版本的modelscope和swift" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "68bbc6b0-e4e5-4663-a28e-fec1e667427b", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install modelscope ms-swift -U\n", | ||
"!pip install tf-keras==2.16.0 --no-dependencies" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6bc13d59-415b-4455-a965-49cead0904f7", | ||
"metadata": {}, | ||
"source": [ | ||
"2. 加载数据集" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "9a292aab-04d2-4b68-8507-53f1ea22bb9f", | ||
"metadata": { | ||
"execution": { | ||
"iopub.execute_input": "2024-11-01T08:24:57.831096Z", | ||
"iopub.status.busy": "2024-11-01T08:24:57.830826Z", | ||
"iopub.status.idle": "2024-11-01T08:25:07.688996Z", | ||
"shell.execute_reply": "2024-11-01T08:25:07.688492Z", | ||
"shell.execute_reply.started": "2024-11-01T08:24:57.831080Z" | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from modelscope import MsDataset\n", | ||
"dataset = MsDataset.load('swift/classical_chinese_translate')\n", | ||
"dataset = dataset['train']" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "fe841e73-3f85-4a4d-879c-0b9064fe9e98", | ||
"metadata": {}, | ||
"source": [ | ||
"3. 查看数据集内容" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "817f4c82-cf51-4504-8866-09eff41fe732", | ||
"metadata": { | ||
"ExecutionIndicator": { | ||
"show": true | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"print(dataset)\n", | ||
"print(dataset['conversations'][0])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ac948dfd-f290-43c4-bf4f-7ece4883ccad", | ||
"metadata": {}, | ||
"source": [ | ||
"4. 加载模型、分词器、模板" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "231ff476-df34-470b-8b73-4bd24e2f298a", | ||
"metadata": { | ||
"ExecutionIndicator": { | ||
"show": true | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from modelscope import AutoModelForCausalLM\n", | ||
"from modelscope import AutoTokenizer\n", | ||
"from swift.llm import get_template, TemplateType\n", | ||
"tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-7B-Instruct', trust_remote_code=True)\n", | ||
"model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-7B-Instruct', torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)\n", | ||
"print(model)\n", | ||
"print(tokenizer('I like AI'))\n", | ||
"\n", | ||
"\n", | ||
"template = get_template(TemplateType.qwen2_5, tokenizer, max_length=400)\n", | ||
"ret = template.encode({'query': 'what is your hobby?', 'response': 'I like AI'})\n", | ||
"print(ret)\n", | ||
"tokenizer.decode(ret[0]['input_ids'])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4dd66f8b-921c-4433-be8d-122bc182920c", | ||
"metadata": {}, | ||
"source": [ | ||
"5. 预处理数据集" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f41eebe2-7009-41cf-a936-5d4b471daba4", | ||
"metadata": { | ||
"ExecutionIndicator": { | ||
"show": true | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from swift.llm.utils.preprocess import ConversationsPreprocessor\n", | ||
"ds = ConversationsPreprocessor()(dataset)\n", | ||
"print(ds[0])\n", | ||
"\n", | ||
"def encode(example):\n", | ||
" example, kwargs = template.encode(example)\n", | ||
" if 'input_ids' not in example:\n", | ||
" return {\n", | ||
" 'input_ids': None,\n", | ||
" 'labels': None,\n", | ||
" }\n", | ||
" return example\n", | ||
"\n", | ||
"ds = ds.select(range(300)).map(encode).filter(lambda e: e.get('input_ids'))\n", | ||
"ds = ds.train_test_split(test_size=0.1)\n", | ||
"\n", | ||
"train_dataset, val_dataset = ds['train'], ds['test']\n", | ||
"print('===========================================')\n", | ||
"print(train_dataset[0])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8eb90b4e-7e65-48d7-9f62-d9b226918618", | ||
"metadata": {}, | ||
"source": [ | ||
"6. 加载LoRA" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "9cf58c93-a3e6-4672-80d9-06acc25bd9e8", | ||
"metadata": { | ||
"ExecutionIndicator": { | ||
"show": true | ||
}, | ||
"execution": { | ||
"iopub.execute_input": "2024-11-01T08:25:51.704754Z", | ||
"iopub.status.busy": "2024-11-01T08:25:51.704328Z", | ||
"iopub.status.idle": "2024-11-01T08:25:52.427483Z", | ||
"shell.execute_reply": "2024-11-01T08:25:52.426958Z", | ||
"shell.execute_reply.started": "2024-11-01T08:25:51.704723Z" | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from swift import Swift, LoraConfig\n", | ||
"\n", | ||
"\n", | ||
"lora_config = LoraConfig(\n", | ||
" r=8,\n", | ||
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n", | ||
" lora_alpha=32,\n", | ||
" lora_dropout=0.05)\n", | ||
"model = Swift.prepare_model(model, lora_config)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3a06c9e0-24ff-49a7-b901-54d800248407", | ||
"metadata": {}, | ||
"source": [ | ||
"7. 训练" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "add78662-d30f-48b7-af7c-358840af1433", | ||
"metadata": { | ||
"ExecutionIndicator": { | ||
"show": true | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# A100 18G memory\n", | ||
"from swift import Seq2SeqTrainer, Seq2SeqTrainingArguments\n", | ||
"import torch\n", | ||
"\n", | ||
"\n", | ||
"train_args = Seq2SeqTrainingArguments(\n", | ||
" output_dir='output',\n", | ||
" learning_rate=1e-4,\n", | ||
" num_train_epochs=1,\n", | ||
" eval_steps=5,\n", | ||
" save_steps=5,\n", | ||
" evaluation_strategy='no',\n", | ||
" save_strategy='steps',\n", | ||
" dataloader_num_workers=4,\n", | ||
" per_device_train_batch_size=1,\n", | ||
" gradient_accumulation_steps=16,\n", | ||
" logging_steps=2,\n", | ||
")\n", | ||
"\n", | ||
"print(train_dataset[0])\n", | ||
"\n", | ||
"trainer = Seq2SeqTrainer(\n", | ||
" model=model,\n", | ||
" args=train_args,\n", | ||
" data_collator=template.data_collator,\n", | ||
" train_dataset=train_dataset,\n", | ||
" eval_dataset=val_dataset,\n", | ||
" tokenizer=tokenizer)\n", | ||
"\n", | ||
"trainer.train()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "741ed1fd-1623-4221-a7c3-efdd37bc57da", | ||
"metadata": {}, | ||
"source": [ | ||
"8. 看看存了什么" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7e1a5e23-632d-4487-8a5c-fcae84311f51", | ||
"metadata": { | ||
"ExecutionIndicator": { | ||
"show": true | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!ls output/checkpoint-11" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "799089b3-5458-4b54-9865-131fca081b26", | ||
"metadata": {}, | ||
"source": [ | ||
"9. 推理" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "00f72b82-e856-48cd-9be5-38afee347b48", | ||
"metadata": { | ||
"ExecutionIndicator": { | ||
"show": true | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from modelscope import GenerationConfig\n", | ||
"query = '你喜欢什么'\n", | ||
"inputs, kwargs = template.encode({'query': query})\n", | ||
"print(inputs)\n", | ||
"generation_config = GenerationConfig(max_new_tokens=512, top_p=0.7, temperature=0.3)\n", | ||
"inputs['input_ids'] = torch.tensor(inputs['input_ids'])[None].cuda()\n", | ||
"generate_ids = model.generate(generation_config=generation_config, **inputs)\n", | ||
"print(generate_ids)\n", | ||
"print(tokenizer.decode(generate_ids[0]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "95f72cea-a972-4e0c-aad4-4a6f1a5aca25", | ||
"metadata": {}, | ||
"source": [ | ||
"10. 界面" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "bbf4c631-6d3f-4813-8380-96fbf632dac9", | ||
"metadata": { | ||
"ExecutionIndicator": { | ||
"show": true | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install gradio==3.50.2\n", | ||
"!swift web-ui" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters