Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

少量数据训练正常,大量数据时训练loss逐渐变大,有没有什么排查的方向? #706

Open
will624 opened this issue Jan 19, 2025 · 5 comments
Assignees

Comments

@will624
Copy link

will624 commented Jan 19, 2025

在数据量比较多的情况下,训练初期loss先降低,后来慢慢变大,并且开了max_grad_norm。想问问看有没有什么排查的方向看是哪里的问题。训练机器是a800,以下是训练配置:

data_config:
  train_file: train.finetune.list.shuf
#  val_file: dev.list
#  test_file: test.list
  num_proc: 40

combine: True
freezeV: True
max_input_length: 512
max_output_length: 512

training_args:
  # see `transformers.Seq2SeqTrainingArguments`
  output_dir: ./output
  num_train_epochs: 2
  #max_steps: 27812500
  # tune to fit the dataset
  learning_rate: 5e-4
  lr_scheduler_type : cosine
  #lr_scheduler_type : constant_with_warmup
  # settings for data loading
  per_device_train_batch_size: 8
  dataloader_num_workers: 16
  remove_unused_columns: false
  gradient_accumulation_steps: 4
  max_grad_norm: 1.0
  # settings for saving checkpoints
  #save_strategy: epoch
  save_strategy: steps
  save_steps: 40000
  # settings for logging
  log_level: info
  #log_level: debug
  logging_strategy: steps
  logging_steps: 4
  # settings for evaluation
  #per_device_eval_batch_size: 4
  #eval_strategy: epoch
  #eval_steps: 500
  # settings for optimizer
  # adam_epsilon: 1e-6
  # uncomment the following line to detect nan or inf values
  # debug: underflow_overflow
  predict_with_generate: true
  # see `transformers.GenerationConfig`
  generation_config:
    max_new_tokens: 512
  # set your absolute deepspeed path here
  # deepspeed: configs/ds_zero_3.json
peft_config:
  peft_type: LORA
  task_type: CAUSAL_LM
  r: 8
  lora_alpha: 16
  lora_dropout: 0.1
  target_modules: ["query_key_value"]
  #target_modules: ["q_proj", "k_proj", "v_proj"] if model is glm-4-9b-chat-hf

以下是训练代码:

# -*- coding: utf-8 -*-
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, Split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional

# For Ascend NPU, please add this
# import torch_npu
# from torch_npu.contrib import transfer_to_npu

# Earlier, shorter instruction, kept for reference. English gist: "Add punctuation
# (commas, full stops, question marks, exclamation marks) to the following text;
# do not add, delete or modify characters, do not reorder them; prefer full stops
# and commas:"
#prompt = "给以下文字加上标点符号,包括逗号、句号、问号、感叹号,不能增加、删除、修改文字,也不可以调整文字顺序,尽量使用句号和逗号:\n"
# Instruction that process_batch prepends to the content of every user message.
# English gist: "Keep all of the following text; do not add, delete or replace any
# characters, do not edit for fluency, do not remove filler or repeated words; on that
# basis insert punctuation (commas, full stops, question marks, exclamation marks)
# according to meaning, preferring full stops and commas. Removing the punctuation
# from the result must give back exactly the input text. The text to punctuate is:"
prompt = "保留以下文字的全部内容,不要增加、删除、替换任何文字,不要做语义通顺性方面的修改,不要去掉语气词、重复词等,在此基础上结合语义插入标点符号,包括逗号、句号、问号、感叹号,尽量使用句号和逗号。将标点符号删除后的文本内容应该跟输入的文本内容一模一样,以下是需要插入标点的文本:"
# Typer CLI application; `main` is registered on it via @app.command().
# Local variables are not shown in Typer's pretty exception tracebacks.
app = typer.Typer(pretty_exceptions_show_locals=False)


class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    """Seq2Seq collator that also right-pads an optional ``output_ids`` field.

    ``output_ids`` (emitted by ``process_batch_eval``) is padded with
    ``tokenizer.pad_token_id`` up to the longest entry in the batch, rounded up
    to a multiple of ``pad_to_multiple_of`` when that is set. All other fields
    are handled by the parent collator.
    """

    def __call__(self, features, return_tensors=None):
        if "output_ids" not in features[0].keys():
            return super().__call__(features, return_tensors)

        target_len = max(len(feature["output_ids"]) for feature in features)
        multiple = self.pad_to_multiple_of
        if multiple is not None:
            # Round up to the next multiple of `multiple`.
            target_len = ((target_len + multiple - 1) // multiple) * multiple

        pad_id = self.tokenizer.pad_token_id
        for feature in features:
            ids = feature["output_ids"]
            padding = [pad_id] * (target_len - len(ids))
            if isinstance(ids, list):
                feature["output_ids"] = ids + padding
            else:
                # Array-like input: concatenate and force int64 token ids.
                feature["output_ids"] = np.concatenate([ids, padding]).astype(np.int64)
        return super().__call__(features, return_tensors)


class Seq2SeqTrainer(_Seq2SeqTrainer):
    """Seq2SeqTrainer that scores only the generated continuation against ``output_ids``.

    Evaluation rows built by ``process_batch_eval`` carry the prompt in
    ``input_ids`` and the reference answer in ``output_ids``.
    """

    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, Any],
        prediction_loss_only: bool,
        ignore_keys=None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run one evaluation step and return ``(loss, generated_tokens, labels)``.

        ``output_ids`` is always popped from ``inputs``, because it is a reference,
        not something to feed to the model. With ``predict_with_generate``, the
        prompt prefix is sliced off the generated sequences and ``output_ids`` is
        returned as the labels. Otherwise the parent's outputs are returned
        unchanged.
        """
        with torch.no_grad():  # Ensure no gradient computation
            output_ids = inputs.pop("output_ids", None)
            prompt_length = inputs["input_ids"].size(1)

            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )

            if self.args.predict_with_generate:
                # generated_tokens can be None (e.g. when only the loss is requested).
                if generated_tokens is not None:
                    generated_tokens = generated_tokens[:, prompt_length:]
                labels = output_ids

        # Release cached blocks between generation steps to limit eval memory growth.
        torch.cuda.empty_cache()
        return loss, generated_tokens, labels


@dc.dataclass
class DataConfig(object):
    """``data_config`` section of the fine-tuning YAML.

    Each ``*_file`` field names a text file that lists one data-file path per
    line (blank lines ignored). Those paths are passed to
    ``load_dataset("json", ...)`` for the corresponding split.
    """

    train_file: Optional[str] = None
    num_proc: Optional[int] = None
    # Optional list files for the validation/test splits. FinetuningConfig reads
    # `val_file`, and the YAML template carries both keys (commented out).
    val_file: Optional[str] = None
    test_file: Optional[str] = None

    def _read_file_list(self, file_path: Optional[str]) -> list[str]:
        """Return the stripped, non-blank lines of *file_path*.

        Returns ``[]`` when *file_path* is None or empty. Raises
        ``FileNotFoundError`` when a path is given but is not a file, instead of
        silently producing an empty split for ``load_dataset``.
        """
        if not file_path:
            return []
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"data file list not found: {file_path}")
        with open(file_path, "r", encoding="utf-8") as f:
            return [entry for entry in (line.strip() for line in f) if entry]

    @property
    def data_files(self) -> "dict[NamedSplit, list[str]]":
        """Map each configured split to the data-file paths listed in its list file."""
        split_to_list_file = {
            Split.TRAIN: self.train_file,
            Split.VALIDATION: self.val_file,
            Split.TEST: self.test_file,
        }
        return {
            split: self._read_file_list(list_file)
            for split, list_file in split_to_list_file.items()
            if list_file
        }


@dc.dataclass
class FinetuningConfig(object):
    """Top-level fine-tuning configuration, mirroring the YAML config file.

    ``combine`` selects how ``process_batch`` templates conversations.
    ``freezeV`` is read from the config but not referenced by the code shown
    here.
    """

    data_config: DataConfig

    max_input_length: int
    max_output_length: int
    combine: bool
    freezeV: bool

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        # DataConfig may not declare `val_file`; treat a missing attribute as "no validation file".
        val_file = getattr(self.data_config, "val_file", None)
        if not self.training_args.do_eval or val_file is None:
            self.training_args.do_eval = False
            # transformers renamed `evaluation_strategy` to `eval_strategy`; update
            # whichever name the installed version exposes so evaluation is really off.
            for attr in ("eval_strategy", "evaluation_strategy"):
                if hasattr(self.training_args, attr):
                    setattr(self.training_args, attr, "no")
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size
                or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> "FinetuningConfig":
        """Build a config from plain dicts, converting nested sections to their types.

        Raises:
            ValueError: if the required ``data_config`` section is missing.
        """
        training_args = kwargs.get("training_args", None)
        if training_args is not None and not isinstance(
            training_args, Seq2SeqTrainingArguments
        ):
            training_args = dict(training_args)  # do not mutate the caller's mapping
            gen_config = training_args.get("generation_config")
            # `generation_config` is optional; only convert it when present.
            if gen_config is not None and not isinstance(gen_config, GenerationConfig):
                training_args["generation_config"] = GenerationConfig(**gen_config)
            kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get("data_config")
        if data_config is None:
            raise ValueError("config is missing the required 'data_config' section")
        if not isinstance(data_config, DataConfig):
            kwargs["data_config"] = DataConfig(**data_config)

        peft_config = kwargs.get("peft_config", None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs["peft_config"] = get_peft_config(config_dict=peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
        """Load the YAML file at *path* (safe loader) and build the config from it."""
        path = Path(path)
        parser = yaml.YAML(typ="safe", pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)

def add_prefix(example):
    """Prepend the module-level ``prompt`` to every user message of *example*, in place.

    Returns the same (mutated) *example*. ``process_batch`` also prepends
    ``prompt`` to user messages, so applying both would duplicate the prefix.
    """
    user_turns = (turn for turn in example["messages"] if turn["role"] == "user")
    for turn in user_turns:
        turn["content"] = prompt + turn["content"]
    return example

def _load_datasets(
    data_format: str,
    data_files: dict[NamedSplit, str],
    num_proc: Optional[int],
) -> DatasetDict:
    """Load every split in *data_files* with the ``json`` builder.

    Only the ``.jsonl`` format is supported; any other *data_format* raises
    ``NotImplementedError``. Returns a ``DatasetDict`` keyed by split.
    """
    if data_format != ".jsonl":
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return load_dataset(
        "json",
        data_files=data_files,
        split=None,
        num_proc=num_proc,
    )


class DataManager(object):
    """Loads the configured JSON Lines splits and maps a preprocessing function over them."""

    def __init__(self, data_config: DataConfig):
        self._num_proc = data_config.num_proc
        # Every configured split is read as JSON Lines.
        self._dataset_dct = _load_datasets(".jsonl", data_config.data_files, self._num_proc)

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        """Return the raw dataset for *split*, or None if that split was not configured."""
        return self._dataset_dct.get(split)

    def get_dataset(
        self,
        split: NamedSplit,
        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
        batched: bool = True,
        remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        """Apply *process_fn* to *split* with ``Dataset.map``.

        Returns None when the split is not present. When *remove_orig_columns*
        is true, the raw columns are dropped, so only what *process_fn* returns
        remains.
        """
        raw = self._get_dataset(split)
        if raw is None:
            return None
        columns_to_drop = raw.column_names if remove_orig_columns else None
        return raw.map(
            process_fn,
            batched=batched,
            remove_columns=columns_to_drop,
            num_proc=self._num_proc,
        )


def process_message(message):
    """Normalize the ``tools`` field of one chat message before templating (in place).

    - System message with a tool list: drop ``None``-valued entries from each
      tool's ``function.parameters.properties``. Tools without parameters or
      properties are left as-is.
    - System message whose ``tools`` is ``None``, and any non-system message:
      remove the ``tools`` key. Presumably such ``None`` values come from schema
      unification when loading JSON rows with ``datasets``; verify.

    Returns the same message object.
    """
    if "tools" not in message:
        return message

    tools = message["tools"]
    if message["role"] == "system" and tools is not None:
        for tool in tools:
            parameters = (tool.get("function") or {}).get("parameters") or {}
            properties = parameters.get("properties")
            if properties is not None:
                parameters["properties"] = {
                    k: v for k, v in properties.items() if v is not None
                }
    else:
        # Iterating a None tool list would crash; non-system turns never carry tools.
        del message["tools"]
    return message


def process_batch(
    batch: Mapping[str, Sequence],
    tokenizer: "PreTrainedTokenizer",
    max_input_length: int,
    max_output_length: int,
    combine: bool,
    instruction: Optional[str] = None,
) -> dict[str, list]:
    """Tokenize a batch of conversations into causal-LM training rows.

    Args:
        batch: ``Dataset.map`` batch whose ``"messages"`` column holds, per
            conversation, a list of ``{"role", "content", ...}`` dicts.
        tokenizer: tokenizer exposing ``apply_chat_template``.
        max_input_length, max_output_length: each row is truncated to
            ``max_input_length + max_output_length + 1`` tokens.
        combine: if true, template the whole conversation at once and supervise
            only what follows the last assistant marker. Otherwise template
            message by message and supervise every message whose role is not
            system/user/observation.
        instruction: text prepended to every user message. Defaults to the
            module-level ``prompt``.

    Returns:
        ``{"input_ids": [...], "labels": [...]}`` with ``-100`` on unsupervised
        positions. Rows left with no supervised token are dropped: no assistant
        marker in ``combine`` mode, or the answer fully truncated away. Such rows
        add no gradient, and a micro-batch made only of them yields a NaN mean
        loss. Because fewer rows can come out than went in, ``Dataset.map`` must
        remove the original columns (``DataManager.get_dataset`` does so by
        default).
    """
    if instruction is None:
        instruction = prompt
    # Hard-coded token ids. Presumably GLM-4: 151331/151333 = [gMASK]/<sop> prefix,
    # 151337 = <|assistant|> marker, 151336 = chat EOS. Confirm against the tokenizer.
    prefix_ids = [151331, 151333]
    assistant_id = 151337
    eos_id = 151336
    max_length = max_input_length + max_output_length + 1

    batched_input_ids = []
    batched_labels = []
    for conv in batch["messages"]:
        # Prefix copies of the user turns so the incoming batch is not mutated.
        conv = [
            {**message, "content": instruction + message["content"]}
            if message["role"] == "user"
            else dict(message)
            for message in conv
        ]

        if combine:
            input_ids = list(
                tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
            )
            if assistant_id not in input_ids:
                continue  # nothing to supervise in this conversation
            last_assistant_index = (
                len(input_ids) - input_ids[::-1].index(assistant_id) - 1
            )
            n_unsupervised = last_assistant_index + 1
            loss_masks = [False] * n_unsupervised + [True] * (
                len(input_ids) - n_unsupervised
            )
        else:
            input_ids = list(prefix_ids)
            loss_masks = [False] * len(prefix_ids)
            for message in conv:
                message = process_message(message)
                supervised = message["role"] not in ("system", "user", "observation")
                # [2:] drops the prefix tokens the template adds to every call.
                new_input_ids = tokenizer.apply_chat_template(
                    [message], tokenize=True, return_dict=False
                )[2:]
                input_ids += new_input_ids
                loss_masks += [supervised] * len(new_input_ids)

        input_ids.append(eos_id)  # EOS for chat
        # Shift the mask right by one position. The appended EOS then inherits the
        # mask of the token before it, and every supervised span starts one token later.
        loss_masks = [False, *loss_masks]
        labels = [tok if mask else -100 for tok, mask in zip(input_ids, loss_masks)]

        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
        if all(label == -100 for label in labels):
            continue
        batched_input_ids.append(input_ids)
        batched_labels.append(labels)

    return {"input_ids": batched_input_ids, "labels": batched_labels}


def process_batch_eval(
    batch: Mapping[str, Sequence],
    tokenizer: "PreTrainedTokenizer",
    max_input_length: int,
    max_output_length: int,
    combine: bool,
    instruction: Optional[str] = None,
) -> dict[str, list]:
    """Build (prompt, reference) pairs for generation-based evaluation.

    Each pair's ``input_ids`` is the conversation context, truncated to
    *max_input_length*, followed by the assistant marker. Its ``output_ids``
    is the assistant answer that follows that marker, plus EOS, truncated to
    *max_output_length*.

    With *combine*, one pair is produced per conversation, for its final
    assistant turn; conversations without an assistant marker are skipped.
    Without *combine*, one pair is produced per assistant message, until the
    accumulated context reaches *max_input_length*.

    *instruction* (default: module-level ``prompt``) is prepended to every user
    message, exactly as ``process_batch`` does for training, so evaluation
    prompts match what the model was trained on.
    """
    if instruction is None:
        instruction = prompt
    # Same hard-coded ids as process_batch (presumably GLM-4 [gMASK]/<sop>,
    # <|assistant|> and the chat EOS; confirm against the tokenizer).
    prefix_ids = [151331, 151333]
    assistant_id = 151337
    eos_id = 151336

    batched_input_ids = []
    batched_output_ids = []
    for conv in batch["messages"]:
        # Prefix copies of the user turns so the incoming batch is not mutated.
        conv = [
            {**message, "content": instruction + message["content"]}
            if message["role"] == "user"
            else dict(message)
            for message in conv
        ]

        if combine:
            ids = list(
                tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
            )
            if assistant_id not in ids:
                continue
            last_assistant_index = len(ids) - ids[::-1].index(assistant_id) - 1
            # The prompt is everything before the final assistant marker, then the
            # marker itself. The reference is the answer after the marker, plus EOS.
            context = ids[:last_assistant_index]
            output_ids = ids[last_assistant_index + 1 :] + [eos_id]
            batched_input_ids.append(context[:max_input_length] + [assistant_id])
            batched_output_ids.append(output_ids[:max_output_length])
        else:
            input_ids = list(prefix_ids)
            for message in conv:
                if len(input_ids) >= max_input_length:
                    break
                message = process_message(message)
                # [2:] drops the prefix tokens the template adds to every call.
                new_input_ids = tokenizer.apply_chat_template(
                    [message], tokenize=True, return_dict=False
                )[2:]
                if message["role"] == "assistant":
                    # The first token (role marker) ends the prompt; the rest is the reference.
                    output_prompt, output_ids = new_input_ids[:1], new_input_ids[1:]
                    output_ids.append(eos_id)
                    batched_input_ids.append(
                        input_ids[:max_input_length] + output_prompt[:1]
                    )
                    batched_output_ids.append(output_ids[:max_output_length])
                input_ids += new_input_ids

    return {"input_ids": batched_input_ids, "output_ids": batched_output_ids}


def load_tokenizer_and_model(
    model_dir: str,
    peft_config: Optional[PeftConfig] = None,
):
    """Load the tokenizer (left padding) and a bfloat16 causal LM from *model_dir*.

    The model is loaded with ``use_cache=False``. When *peft_config* is given,
    the model is wrapped with ``get_peft_model`` and its trainable-parameter
    summary is printed. Returns ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, padding_side="left", trust_remote_code=True
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        use_cache=False,
        torch_dtype=torch.bfloat16,  # Must use BFloat 16
    )
    if peft_config is None:
        return tokenizer, base_model

    peft_model = get_peft_model(base_model, peft_config)
    peft_model.print_trainable_parameters()
    return tokenizer, peft_model


def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    """Return mean ROUGE-1/2/L F1 (x100) and sentence BLEU-4 over jieba-segmented text.

    ``-100`` entries in the predictions and labels are replaced in place with
    the pad token id. Special tokens (padding, the appended end-of-turn id) are
    skipped when decoding, so they do not inflate the text comparison. A pair
    for which ROUGE cannot be computed scores 0 rather than aborting the whole
    evaluation. A metric with no scored pairs is reported as NaN.
    """
    batched_pred_ids, batched_label_ids = eval_preds
    batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
    batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id

    # Both scorers are stateless; build them once rather than once per pair.
    rouge = Rouge()
    smoothing = SmoothingFunction().method3
    rouge_keys = ("rouge-1", "rouge-2", "rouge-l")
    metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        if len(pred_ids) == 0 and len(label_ids) == 0:
            continue
        pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
        label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        try:
            scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))[0]
        except ValueError:
            # rouge_chinese presumably raises ValueError on an empty hypothesis or
            # reference (as the `rouge` package does). Confirm; score such pairs as 0.
            scores = {key: {"f": 0.0} for key in rouge_keys}
        for key, value in scores.items():
            metrics_dct[key].append(round(value["f"] * 100, 4))
        metrics_dct["bleu-4"].append(
            sentence_bleu(
                [label_tokens],
                pred_tokens,
                smoothing_function=smoothing,
            )
        )
    return {
        key: np.mean(values) if values else float("nan")
        for key, values in metrics_dct.items()
    }


def _latest_checkpoint_step(output_dir: str) -> int:
    """Return the largest N among ``checkpoint-N`` directories in *output_dir*, or 0.

    Entries not named exactly ``checkpoint-<digits>`` (for example temporary
    ``tmp...`` saves) are ignored. A missing *output_dir* means nothing has been
    saved yet.
    """
    import re

    if not os.path.isdir(output_dir):
        return 0
    pattern = re.compile(r"checkpoint-(\d+)")
    steps = [
        int(m.group(1))
        for name in os.listdir(output_dir)
        if (m := pattern.fullmatch(name)) is not None
        and os.path.isdir(os.path.join(output_dir, name))
    ]
    return max(steps, default=0)


@app.command()
def main(
    #data_dir: Annotated[str, typer.Argument(help="")],
    model_dir: Annotated[
        str,
        typer.Argument(
            help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file."
        ),
    ],
    config_file: Annotated[str, typer.Argument(help="")],
    auto_resume_from_checkpoint: str = typer.Argument(
        default="",
        help="If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training",
    ),
):
    """Fine-tune MODEL_DIR with the YAML CONFIG_FILE, optionally resuming from a checkpoint."""
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(
        model_dir, peft_config=ft_config.peft_config
    )
    data_manager = DataManager(ft_config.data_config)

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print("train_dataset:", train_dataset)
    #val_dataset = data_manager.get_dataset(
    #    Split.VALIDATION,
    #    functools.partial(
    #        process_batch_eval,
    #        tokenizer=tokenizer,
    #        combine=ft_config.combine,
    #        max_input_length=ft_config.max_input_length,
    #        max_output_length=ft_config.max_output_length,
    #    ),
    #    batched=True,
    #)
    #if val_dataset is not None:
    #    print("val_dataset:", val_dataset)
    #test_dataset = data_manager.get_dataset(
    #    Split.TEST,
    #    functools.partial(
    #        process_batch_eval,
    #        tokenizer=tokenizer,
    #        combine=ft_config.combine,
    #        max_input_length=ft_config.max_input_length,
    #        max_output_length=ft_config.max_output_length,
    #    ),
    #    batched=True,
    #)
    #if test_dataset is not None:
    #    print("test_dataset:", test_dataset)

    # The config may omit `generation_config`; create one so the ids below can be set.
    if ft_config.training_args.generation_config is None:
        ft_config.training_args.generation_config = GenerationConfig()
    ft_config.training_args.generation_config.pad_token_id = 151329
    ft_config.training_args.generation_config.eos_token_id = [151329, 151336, 151338]

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding="longest",
            return_tensors="pt",
        ),
        train_dataset=train_dataset,
        #eval_dataset=val_dataset,
        #compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    resume = (auto_resume_from_checkpoint or "").strip()
    # "" or "no" (any case): train from scratch, as the help text documents.
    if resume.upper() in ("", "NO"):
        trainer.train()
        return

    output_dir = ft_config.training_args.output_dir
    if resume.upper() == "YES":
        checkpoint_sn = _latest_checkpoint_step(output_dir)
        if checkpoint_sn == 0:
            # Nothing saved yet: start a fresh run.
            trainer.train()
            return
    elif resume.isdecimal():
        checkpoint_sn = int(resume)
    else:
        checkpoint_sn = 0

    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
    if checkpoint_sn <= 0 or not os.path.isdir(checkpoint_directory):
        print(
            resume,
            "The specified checkpoint sn("
            + resume
            + ") has not been saved. Please search for the correct checkpoint in the model output directory",
        )
        # Exit non-zero so schedulers/scripts notice that no training happened.
        raise typer.Exit(code=1)

    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
    trainer.train(resume_from_checkpoint=checkpoint_directory)

    #if test_dataset is not None:
    #    trainer.predict(test_dataset)


# Script entry point: hand control to the Typer CLI (`main` is its command).
if __name__ == "__main__":
    app()

以下是部分日志:
train_dataset: Dataset({
features: ['input_ids', 'labels'],
num_rows: 180431597
})
[2025-01-18 23:24:31,286] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
[WARNING] using untested triton version (3.1.0), only 1.0.0 is known to be compatible
/mnt/LM_disk12/weichenchuang/env/conda_for_hg.3.10py/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:49: FutureWarning: torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead.
def forward(ctx, input, weight, bias=None):
/mnt/LM_disk12/weichenchuang/env/conda_for_hg.3.10py/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:67: FutureWarning: torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead.
def backward(ctx, grad_output):
NCCL version 2.21.5+cuda12.4
[2025-01-18 23:24:34,048] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
[2025-01-18 23:24:34,719] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
[WARNING] using untested triton version (3.1.0), only 1.0.0 is known to be compatible
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
/mnt/LM_disk12/weichenchuang/env/conda_for_hg.3.10py/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:49: FutureWarning: torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead.
def forward(ctx, input, weight, bias=None):
/mnt/LM_disk12/weichenchuang/env/conda_for_hg.3.10py/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:67: FutureWarning: torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead.
def backward(ctx, grad_output):
[2025-01-18 23:24:35,040] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
[WARNING] using untested triton version (3.1.0), only 1.0.0 is known to be compatible
/mnt/LM_disk12/weichenchuang/env/conda_for_hg.3.10py/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:49: FutureWarning: torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead.
def forward(ctx, input, weight, bias=None):
/mnt/LM_disk12/weichenchuang/env/conda_for_hg.3.10py/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:67: FutureWarning: torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead.
def backward(ctx, grad_output):
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
[WARNING] using untested triton version (3.1.0), only 1.0.0 is known to be compatible
/mnt/LM_disk12/weichenchuang/env/conda_for_hg.3.10py/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:49: FutureWarning: torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead.
def forward(ctx, input, weight, bias=None):
/mnt/LM_disk12/weichenchuang/env/conda_for_hg.3.10py/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:67: FutureWarning: torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead.
def backward(ctx, grad_output):
[2025-01-18 23:24:36,284] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-01-18 23:24:36,299] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH

以下是loss曲线:

[Image: loss-curve screenshot, not preserved in this copy]

以下是grad_norm比较大的loss:
{'loss': 0.094, 'grad_norm': 3.6441407203674316, 'learning_rate': 0.0004997769340792326, 'epoch': 0.03}
{'loss': 0.1019, 'grad_norm': 3.8415417671203613, 'learning_rate': 0.0004997768399428072, 'epoch': 0.03}
{'loss': 0.0922, 'grad_norm': 3.0973892211914062, 'learning_rate': 0.0004997767457865314, 'epoch': 0.03}
{'loss': 0.0894, 'grad_norm': 2.5256614685058594, 'learning_rate': 0.0004997766516104055, 'epoch': 0.03}
{'loss': 0.0928, 'grad_norm': 2.3791558742523193, 'learning_rate': 0.000499776557414429, 'epoch': 0.03}
{'loss': 0.0949, 'grad_norm': 2.533939838409424, 'learning_rate': 0.0004997764631986026, 'epoch': 0.03}
{'loss': 0.0946, 'grad_norm': 5.6503705978393555, 'learning_rate': 0.0004997763689629258, 'epoch': 0.03}
{'loss': 0.1045, 'grad_norm': 237.627197265625, 'learning_rate': 0.0004997762747073987, 'epoch': 0.03}
{'loss': 0.1246, 'grad_norm': 4.959110260009766, 'learning_rate': 0.0004997761804320215, 'epoch': 0.03}
{'loss': 0.0988, 'grad_norm': 5.715487003326416, 'learning_rate': 0.0004997760861367939, 'epoch': 0.03}
{'loss': 0.1025, 'grad_norm': 3.7426841259002686, 'learning_rate': 0.0004997759918217163, 'epoch': 0.03}
{'loss': 0.0932, 'grad_norm': 3.6973986625671387, 'learning_rate': 0.0004997758974867883, 'epoch': 0.03}
{'loss': 0.0963, 'grad_norm': 8.751940727233887, 'learning_rate': 0.0004997758031320102, 'epoch': 0.03}
{'loss': 0.0924, 'grad_norm': 4.516246795654297, 'learning_rate': 0.0004997757087573818, 'epoch': 0.03}
{'loss': 0.1054, 'grad_norm': 2.862150192260742, 'learning_rate': 0.0004997756143629034, 'epoch': 0.03}
{'loss': 0.0935, 'grad_norm': 3.533675193786621, 'learning_rate': 0.0004997755199485747, 'epoch': 0.03}
{'loss': 0.094, 'grad_norm': 4.6160688400268555, 'learning_rate': 0.0004997754255143959, 'epoch': 0.03}

@zhipuch zhipuch self-assigned this Jan 20, 2025
@zhipuch
Copy link
Collaborator

zhipuch commented Jan 20, 2025

前面可以直接停了吧,loss很快最小且平稳

@will624
Copy link
Author

will624 commented Jan 20, 2025

前面可以直接停了吧,loss很快最小且平稳

就过了几个样本啊,这里应该是训练有问题。
另外问一下,训练是要用bf16混合精度,还是完全用bf16训练?
@zhipuch

@zRzRzRzRzRzRzR
Copy link
Member

都可以,可以直接用BF16就行

@will624
Copy link
Author

will624 commented Jan 20, 2025

都可以,可以直接用BF16就行

好的,再请教下,lora训练的过程中,我看glm的参数有一部分是bf16的,lora的参数基本上都是fp32的,这样是正确的吗?
@zRzRzRzRzRzRzR

@zRzRzRzRzRzRzR
Copy link
Member

应该是按照bf16微调的,fp32是部分操作需要转fp32计算,不影响的吧。你说的fp32的权重(就是需要保存的权重)在哪里?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants