Commit b9451ab: fixing issues and quality ✨
pacman100 committed Mar 7, 2023
1 parent a84414f commit b9451ab
Showing 9 changed files with 73 additions and 56 deletions.
@@ -4,9 +4,12 @@
import threading

import numpy as np
import psutil
import torch
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -15,10 +18,7 @@
set_seed,
)

import psutil
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from tqdm import tqdm


def levenshtein_distance(str1, str2):
@@ -280,7 +280,9 @@ def test_preprocess_function(examples):
outputs = accelerator.unwrap_model(model).generate(
**batch, synced_gpus=is_ds_zero_3, max_new_tokens=10
) # synced_gpus=True for DS-stage 3
preds = outputs[:, max_length:].detach().cpu().numpy()
outputs = accelerator.pad_across_processes(outputs, dim=1, pad_index=tokenizer.pad_token_id)
preds = accelerator.gather(outputs)
preds = preds[:, max_length:].detach().cpu().numpy()
eval_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))

# Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
@@ -304,6 +306,9 @@ def test_preprocess_function(examples):

correct = 0
total = 0
assert len(eval_preds) == len(
dataset["train"][label_column]
), f"{len(eval_preds)} != {len(dataset['train'][label_column])}"
for pred, true in zip(eval_preds, dataset["train"][label_column]):
if pred.strip() == true.strip():
correct += 1
@@ -322,15 +327,17 @@ def test_preprocess_function(examples):
outputs = accelerator.unwrap_model(model).generate(
**batch, synced_gpus=is_ds_zero_3, max_new_tokens=10
) # synced_gpus=True for DS-stage 3
test_preds.extend(
tokenizer.batch_decode(outputs[:, max_length:].detach().cpu().numpy(), skip_special_tokens=True)
)
outputs = accelerator.pad_across_processes(outputs, dim=1, pad_index=tokenizer.pad_token_id)
preds = accelerator.gather(outputs)
preds = preds[:, max_length:].detach().cpu().numpy()
test_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))

test_preds_cleaned = []
for _, pred in enumerate(test_preds):
test_preds_cleaned.append(get_closest_label(pred, classes))

test_df = dataset["test"].to_pandas()
assert len(test_preds_cleaned) == len(test_df), f"{len(test_preds_cleaned)} != {len(test_df)}"
test_df[label_column] = test_preds_cleaned
test_df["text_labels_orig"] = test_preds
accelerator.print(test_df[[text_column, label_column]].sample(20))
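The generation loop above carries the core fix in this file: generations are padded to a common length across processes and gathered before decoding, and the prompt tokens are stripped for the decoder-only model. A minimal sketch of that pattern, assuming an existing accelerator, model, tokenizer, eval_dataloader, max_length, and is_ds_zero_3 as in the example script (illustrative, not an excerpt):

import torch

model.eval()
eval_preds = []
for batch in eval_dataloader:
    batch = {k: v for k, v in batch.items() if k != "labels"}
    with torch.no_grad():
        outputs = accelerator.unwrap_model(model).generate(
            **batch, synced_gpus=is_ds_zero_3, max_new_tokens=10
        )  # synced_gpus=True is needed for DeepSpeed ZeRO stage 3
    # Generated sequences can differ in length across processes, so pad them to a
    # common length before gathering; gathering ragged tensors would fail.
    outputs = accelerator.pad_across_processes(outputs, dim=1, pad_index=tokenizer.pad_token_id)
    preds = accelerator.gather(outputs)
    # The causal LM echoes the prompt, so keep only the newly generated tokens.
    preds = preds[:, max_length:].detach().cpu().numpy()
    eval_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))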
@@ -4,15 +4,15 @@
import threading

import numpy as np
import psutil
import torch
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup, set_seed

import psutil
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from tqdm import tqdm


def levenshtein_distance(str1, str2):
@@ -230,7 +230,8 @@ def collate_fn(examples):
outputs = accelerator.unwrap_model(model).generate(
**batch, synced_gpus=is_ds_zero_3
) # synced_gpus=True for DS-stage 3
preds = outputs.detach().cpu().numpy()
outputs = accelerator.pad_across_processes(outputs, dim=1, pad_index=tokenizer.pad_token_id)
preds = accelerator.gather(outputs).detach().cpu().numpy()
eval_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))

# Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
@@ -254,6 +255,9 @@ def collate_fn(examples):

correct = 0
total = 0
assert len(eval_preds) == len(
dataset["train"][label_column]
), f"{len(eval_preds)} != {len(dataset['train'][label_column])}"
for pred, true in zip(eval_preds, dataset["train"][label_column]):
if pred.strip() == true.strip():
correct += 1
@@ -272,13 +276,16 @@ def collate_fn(examples):
outputs = accelerator.unwrap_model(model).generate(
**batch, synced_gpus=is_ds_zero_3
) # synced_gpus=True for DS-stage 3
test_preds.extend(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))
outputs = accelerator.pad_across_processes(outputs, dim=1, pad_index=tokenizer.pad_token_id)
preds = accelerator.gather(outputs).detach().cpu().numpy()
test_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))

test_preds_cleaned = []
for _, pred in enumerate(test_preds):
test_preds_cleaned.append(get_closest_label(pred, classes))

test_df = dataset["test"].to_pandas()
assert len(test_preds_cleaned) == len(test_df), f"{len(test_preds_cleaned)} != {len(test_df)}"
test_df[label_column] = test_preds_cleaned
test_df["text_labels_orig"] = test_preds
accelerator.print(test_df[[text_column, label_column]].sample(20))
@@ -2,13 +2,13 @@

import torch
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup

from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy
from tqdm import tqdm


def main():
14 changes: 7 additions & 7 deletions examples/lora_dreambooth/train_dreambooth.py
@@ -11,20 +11,17 @@
from pathlib import Path
from typing import Optional

import datasets
import diffusers
import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from torch.utils.data import Dataset
from transformers import AutoTokenizer, PretrainedConfig

import datasets
import diffusers
import psutil
from diffusers import (
AutoencoderKL,
DDPMScheduler,
@@ -36,10 +33,13 @@
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
from peft import LoraConfig, LoraModel, get_peft_model_state_dict
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

from peft import LoraConfig, LoraModel, get_peft_model_state_dict


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -8,6 +8,10 @@ select = ["C", "E", "F", "I", "W"]
line-length = 119

[tool.ruff.isort]
lines-after-imports = 2
known-first-party = ["peft"]

[isort]
default_section = "FIRSTPARTY"
known_first_party = "peft"
known_third_party = [
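The new [tool.ruff.isort] settings are what drive the import reshuffling in the files above: known-first-party = ["peft"] puts peft imports in their own first-party block after third-party packages such as psutil, datasets, and tqdm, and lines-after-imports = 2 enforces two blank lines after the import block. An illustrative import block under these settings (module names chosen for the example, not copied from any one file):

import numpy as np
import psutil
import torch
from accelerate import Accelerator
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, set_seed

from peft import LoraConfig, TaskType, get_peft_model


def main():
    ...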
8 changes: 4 additions & 4 deletions src/peft/peft_model.py
@@ -22,13 +22,12 @@
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
from accelerate.utils import get_balanced_memory
from huggingface_hub import hf_hub_download
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin

from huggingface_hub import hf_hub_download

from .tuners import LoraModel, PrefixEncoder, PromptEmbedding, PromptEncoder
from .utils import (
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
@@ -156,7 +155,8 @@ def from_pretrained(cls, model, model_id, **kwargs):
)

adapters_weights = torch.load(
filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
# load the weights into the model
model = set_peft_model_state_dict(model, adapters_weights)
if getattr(model, "hf_device_map", None) is not None:
@@ -271,7 +271,7 @@ def print_trainable_parameters(self):
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel

all_param += num_params
if param.requires_grad:
trainable_params += param.numel()
14 changes: 7 additions & 7 deletions src/peft/tuners/p_tuning.py
@@ -14,6 +14,7 @@
# limitations under the License.

import enum
import warnings
from dataclasses import dataclass, field
from typing import Union

@@ -131,17 +132,16 @@ def __init__(self, config):
)

elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
warnings.warn(
f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
)
layers = [
torch.nn.Linear(self.input_size, self.hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(self.hidden_size, self.hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(self.hidden_size, self.output_size),
]
layers.extend(
[
torch.nn.Linear(self.hidden_size, self.hidden_size),
torch.nn.ReLU(),
]
)
layers.append(torch.nn.Linear(self.hidden_size, self.output_size))
self.mlp_head = torch.nn.Sequential(*layers)

else:
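For the MLP re-parameterization, the layer list above is now built in one flat literal and a warning notes that encoder_num_layers is ignored. A standalone sketch of the resulting mlp_head, with illustrative sizes standing in for the PromptEncoder config values:

import warnings

import torch

input_size, hidden_size, output_size = 1024, 128, 1024  # illustrative sizes
warnings.warn("for MLP, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used.")
mlp_head = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, output_size),
)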
5 changes: 2 additions & 3 deletions src/peft/utils/config.py
@@ -18,9 +18,8 @@
from dataclasses import asdict, dataclass, field
from typing import Optional, Union

from transformers.utils import PushToHubMixin

from huggingface_hub import hf_hub_download
from transformers.utils import PushToHubMixin

from .adapters_utils import CONFIG_NAME

@@ -98,7 +97,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
else:
try:
config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME)
except:
except Exception:
raise ValueError(f"Can't find config.json at '{pretrained_model_name_or_path}'")

loaded_attributes = cls.from_json_file(config_file)
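The bare except: above becomes except Exception:, so keyboard interrupts and system exits are no longer swallowed when the hub download fails. A minimal sketch of the pattern, with a placeholder repository id and filename standing in for the real arguments:

from huggingface_hub import hf_hub_download

try:
    config_file = hf_hub_download("some-user/some-peft-model", "adapter_config.json")
except Exception:
    raise ValueError("Can't find config.json at 'some-user/some-peft-model'")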
42 changes: 21 additions & 21 deletions tests/test_peft_model.py
@@ -42,27 +42,27 @@ class PeftTestMixin:
PromptTuningConfig,
)
config_kwargs = (
dict(
r=8,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
),
dict(
num_virtual_tokens=10,
task_type="CAUSAL_LM",
),
dict(
num_virtual_tokens=10,
encoder_hidden_size=32,
task_type="CAUSAL_LM",
),
dict(
num_virtual_tokens=10,
task_type="CAUSAL_LM",
),
{
"r": 8,
"lora_alpha": 32,
"target_modules": ["q_proj", "v_proj"],
"lora_dropout": 0.05,
"bias": "none",
"task_type": "CAUSAL_LM",
},
{
"num_virtual_tokens": 10,
"task_type": "CAUSAL_LM",
},
{
"num_virtual_tokens": 10,
"encoder_hidden_size": 32,
"task_type": "CAUSAL_LM",
},
{
"num_virtual_tokens": 10,
"task_type": "CAUSAL_LM",
},
)


