chore: lint
nickcom007 committed Dec 31, 2024
1 parent 0d6c23e commit 1e50376
Showing 6 changed files with 47 additions and 28 deletions.
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,13 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.6.4
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format
24 changes: 16 additions & 8 deletions dataset.py
@@ -4,7 +4,7 @@
 import torch
 from loguru import logger
 from torch.utils.data import Dataset
-from utils.tool_utils import tool_formater, function_formatter
+from utils.tool_utils import function_formatter
 
 
 class SFTDataset(Dataset):
@@ -50,22 +50,30 @@ def __getitem__(self, index):
 
             if role != "assistant":
                 if role == "user":
-                    human = self.user_format.format(content=content, stop_token=self.tokenizer.eos_token)
+                    human = self.user_format.format(
+                        content=content, stop_token=self.tokenizer.eos_token
+                    )
                     input_buffer += human
 
                 elif role == "function_call":
                     tool_calls = function_formatter(json.loads(content))
                     function = self.function_format.format(content=tool_calls)
                     input_buffer += function
 
                 elif role == "observation":
                     observation = self.observation_format.format(content=content)
                     input_buffer += observation
             else:
-                assistant = self.assistant_format.format(content=content, stop_token=self.tokenizer.eos_token)
+                assistant = self.assistant_format.format(
+                    content=content, stop_token=self.tokenizer.eos_token
+                )
 
-                input_tokens = self.tokenizer.encode(input_buffer, add_special_tokens=False)
-                output_tokens = self.tokenizer.encode(assistant, add_special_tokens=False)
+                input_tokens = self.tokenizer.encode(
+                    input_buffer, add_special_tokens=False
+                )
+                output_tokens = self.tokenizer.encode(
+                    assistant, add_special_tokens=False
+                )
 
                 input_ids += input_tokens + output_tokens
                 target_mask += [0] * len(input_tokens) + [1] * len(output_tokens)
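
Note (not part of the commit): the dataset.py hunk above only reflows long calls for ruff-format; the labeling logic is unchanged. As a minimal illustrative sketch with made-up token ids, the convention the last two lines implement is that prompt-side tokens (user, function_call, observation turns) are masked with 0 and assistant tokens with 1:

# Illustrative sketch only, not code from the repo.
input_tokens = [11, 12, 13]  # hypothetical ids for the accumulated prompt buffer
output_tokens = [21, 22]     # hypothetical ids for the assistant reply

input_ids = input_tokens + output_tokens
target_mask = [0] * len(input_tokens) + [1] * len(output_tokens)

assert len(input_ids) == len(target_mask)
# target_mask == [0, 0, 0, 1, 1]: positions marked 0 are excluded from the loss;
# only the assistant tokens (marked 1) are trained on.
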
9 changes: 3 additions & 6 deletions demo.py
@@ -7,7 +7,6 @@
 from trl import SFTTrainer, SFTConfig
 
 from dataset import SFTDataCollator, SFTDataset
-from merge import merge_lora_to_base_model
 from utils.constants import model2template
 
 
@@ -96,8 +95,8 @@ def train_lora(
     # upload lora weights and tokenizer
     print("Training Completed.")
 
+
 if __name__ == "__main__":
-
     # Define training arguments for LoRA fine-tuning
     training_args = LoraTrainingArguments(
         num_train_epochs=3,
@@ -114,7 +113,5 @@ def train_lora(
 
     # Start LoRA fine-tuning
     train_lora(
-        model_id=model_id,
-        context_length=context_length,
-        training_args=training_args
+        model_id=model_id, context_length=context_length, training_args=training_args
     )
3 changes: 1 addition & 2 deletions full_automation.py
@@ -1,6 +1,5 @@
 import json
 import os
-import time
 
 import requests
 import yaml
@@ -69,7 +68,7 @@
                 exist_ok=False,
                 repo_type="model",
             )
-        except Exception as e:
+        except Exception:
             logger.info(
                 f"Repo {repo_name} already exists. Will commit the new version."
             )
1 change: 0 additions & 1 deletion merge.py
@@ -1,4 +1,3 @@
-import json
 import torch
 from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
25 changes: 14 additions & 11 deletions utils/tool_utils.py
@@ -14,7 +14,6 @@
 DEFAULT_FUNCTION_SLOTS = "Action: {name}\nAction Input: {arguments}\n"
 
 
-
 def tool_formater(tools: List[Dict[str, Any]]) -> str:
     tool_text = ""
     tool_names = []
@@ -29,7 +28,9 @@ def tool_formater(tools: List[Dict[str, Any]]) -> str:
                 enum = ", should be one of [{}]".format(", ".join(param["enum"]))
 
             if param.get("items", None):
-                items = ", where each item should be {}".format(param["items"].get("type", ""))
+                items = ", where each item should be {}".format(
+                    param["items"].get("type", "")
+                )
 
             param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                 name=name,
@@ -45,22 +46,24 @@ def tool_formater(tools: List[Dict[str, Any]]) -> str:
             )
         tool_names.append(tool["name"])
 
-    return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+    return DEFAULT_TOOL_PROMPT.format(
+        tool_text=tool_text, tool_names=", ".join(tool_names)
+    )
 
 
 def function_formatter(tool_calls, function_slots=DEFAULT_FUNCTION_SLOTS) -> str:
-    functions : List[Tuple[str, str]] = []
+    functions: List[Tuple[str, str]] = []
     if not isinstance(tool_calls, list):
-        tool_calls = [tool_calls] # parrallel function calls
+        tool_calls = [tool_calls]  # parrallel function calls
 
     for tool_call in tool_calls:
-        functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+        functions.append(
+            (tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
+        )
 
     elements = []
     for name, arguments in functions:
         text = function_slots.format(name=name, arguments=arguments)
         elements.append(text)
 
-    return "\n".join(elements)+"\n"
-
-
+    return "\n".join(elements) + "\n"
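
Note (not part of the commit): the utils/tool_utils.py changes are formatting-only, so function_formatter still renders each tool call through DEFAULT_FUNCTION_SLOTS. A small usage sketch; the "get_weather" name and its arguments are hypothetical, invented for illustration:

from utils.tool_utils import function_formatter

call = {"name": "get_weather", "arguments": {"city": "Paris"}}  # hypothetical payload
print(function_formatter(call))
# Prints one "Action:" block per call, followed by a blank line:
# Action: get_weather
# Action Input: {"city": "Paris"}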
