From 1e50376cc325d4c83cdf315e77afbe2f677061c7 Mon Sep 17 00:00:00 2001
From: Nick W
Date: Tue, 31 Dec 2024 05:40:19 -0500
Subject: [PATCH] chore: lint

---
 .pre-commit-config.yaml | 13 +++++++++++++
 dataset.py              | 24 ++++++++++++++++--------
 demo.py                 |  9 +++------
 full_automation.py      |  3 +--
 merge.py                |  1 -
 utils/tool_utils.py     | 25 ++++++++++++++-----------
 6 files changed, 47 insertions(+), 28 deletions(-)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..57d62e6
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,13 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.6.4
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format
diff --git a/dataset.py b/dataset.py
index 634abce..c717595 100644
--- a/dataset.py
+++ b/dataset.py
@@ -4,7 +4,7 @@
 import torch
 from loguru import logger
 from torch.utils.data import Dataset
-from utils.tool_utils import tool_formater, function_formatter
+from utils.tool_utils import function_formatter
 
 
 class SFTDataset(Dataset):
@@ -50,22 +50,30 @@ def __getitem__(self, index):
 
             if role != "assistant":
                 if role == "user":
-                    human = self.user_format.format(content=content, stop_token=self.tokenizer.eos_token)
+                    human = self.user_format.format(
+                        content=content, stop_token=self.tokenizer.eos_token
+                    )
                     input_buffer += human
-
+
                 elif role == "function_call":
                     tool_calls = function_formatter(json.loads(content))
                     function = self.function_format.format(content=tool_calls)
                     input_buffer += function
-
+
                 elif role == "observation":
                     observation = self.observation_format.format(content=content)
                     input_buffer += observation
             else:
-                assistant = self.assistant_format.format(content=content, stop_token=self.tokenizer.eos_token)
-
-                input_tokens = self.tokenizer.encode(input_buffer, add_special_tokens=False)
-                output_tokens = self.tokenizer.encode(assistant, add_special_tokens=False)
+                assistant = self.assistant_format.format(
+                    content=content, stop_token=self.tokenizer.eos_token
+                )
+
+                input_tokens = self.tokenizer.encode(
+                    input_buffer, add_special_tokens=False
+                )
+                output_tokens = self.tokenizer.encode(
+                    assistant, add_special_tokens=False
+                )
 
                 input_ids += input_tokens + output_tokens
                 target_mask += [0] * len(input_tokens) + [1] * len(output_tokens)
diff --git a/demo.py b/demo.py
index eb83efc..4fb1b73 100644
--- a/demo.py
+++ b/demo.py
@@ -7,7 +7,6 @@
 from trl import SFTTrainer, SFTConfig
 
 from dataset import SFTDataCollator, SFTDataset
-from merge import merge_lora_to_base_model
 from utils.constants import model2template
 
 
@@ -96,8 +95,8 @@ def train_lora(
 
     # upload lora weights and tokenizer
    print("Training Completed.")
 
+
 if __name__ == "__main__":
-
     # Define training arguments for LoRA fine-tuning
     training_args = LoraTrainingArguments(
@@ -114,7 +113,5 @@ def train_lora(
 
     # Start LoRA fine-tuning
     train_lora(
-        model_id=model_id,
-        context_length=context_length,
-        training_args=training_args
-    )
\ No newline at end of file
+        model_id=model_id, context_length=context_length, training_args=training_args
+    )
diff --git a/full_automation.py b/full_automation.py
index 1957be7..d1efd5e 100644
--- a/full_automation.py
+++ b/full_automation.py
@@ -1,6 +1,5 @@
 import json
 import os
-import time
 
 import requests
 import yaml
@@ -69,7 +68,7 @@
                 exist_ok=False,
                 repo_type="model",
             )
-        except Exception as e:
+        except Exception:
             logger.info(
                 f"Repo {repo_name} already exists. Will commit the new version."
             )
diff --git a/merge.py b/merge.py
index 80cb6ab..2580bb3 100644
--- a/merge.py
+++ b/merge.py
@@ -1,4 +1,3 @@
-import json
 import torch
 from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
diff --git a/utils/tool_utils.py b/utils/tool_utils.py
index 2ac07a3..2aa2934 100644
--- a/utils/tool_utils.py
+++ b/utils/tool_utils.py
@@ -14,7 +14,6 @@
 DEFAULT_FUNCTION_SLOTS = "Action: {name}\nAction Input: {arguments}\n"
 
 
-
 def tool_formater(tools: List[Dict[str, Any]]) -> str:
     tool_text = ""
     tool_names = []
@@ -29,7 +28,9 @@ def tool_formater(tools: List[Dict[str, Any]]) -> str:
                 enum = ", should be one of [{}]".format(", ".join(param["enum"]))
 
             if param.get("items", None):
-                items = ", where each item should be {}".format(param["items"].get("type", ""))
+                items = ", where each item should be {}".format(
+                    param["items"].get("type", "")
+                )
 
             param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                 name=name,
@@ -45,22 +46,24 @@ def tool_formater(tools: List[Dict[str, Any]]) -> str:
             )
         tool_names.append(tool["name"])
 
-    return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+    return DEFAULT_TOOL_PROMPT.format(
+        tool_text=tool_text, tool_names=", ".join(tool_names)
+    )
 
 
 def function_formatter(tool_calls, function_slots=DEFAULT_FUNCTION_SLOTS) -> str:
-    functions : List[Tuple[str, str]] = []
+    functions: List[Tuple[str, str]] = []
 
     if not isinstance(tool_calls, list):
-        tool_calls = [tool_calls] # parrallel function calls
-
+        tool_calls = [tool_calls]  # parrallel function calls
+
     for tool_call in tool_calls:
-        functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
-
+        functions.append(
+            (tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
+        )
+
     elements = []
     for name, arguments in functions:
         text = function_slots.format(name=name, arguments=arguments)
         elements.append(text)
-
-    return "\n".join(elements)+"\n"
-
+    return "\n".join(elements) + "\n"