From 4a68efb7da274e1ab1550e3e018870e089d2a180 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Wed, 28 Aug 2024 17:01:58 +0800
Subject: [PATCH] [Colossal-LLaMA] Refactor latest APIs (#6030)

* refactor latest code
* update api
* add dummy dataset
* update Readme
* add setup
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update files
* add PP support
* update arguments
* update argument
* reorg folder
* update version
* remove IB infor
* update utils
* update readme
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update save for zero
* update save
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add apex
* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 applications/Colossal-LLaMA/README.md         |  36 +-
 .../colossal_llama/dataset/dummy_dataset.py   |  24 +
 .../utils/flash_attention_patch.py            | 352 --------------
 .../colossal_llama/utils/utils.py             |  36 ++
 .../{ => dataset}/prepare_pretrain_dataset.py |   0
 .../{ => dataset}/prepare_sft_dataset.py      |   0
 .../{ => inference}/inference_example.py      |   0
 .../{ => inference}/stream_chat_example.py    |   0
 applications/Colossal-LLaMA/requirements.txt  |   6 +-
 applications/Colossal-LLaMA/setup.py          |  37 ++
 applications/Colossal-LLaMA/train.example.sh  |  23 +-
 applications/Colossal-LLaMA/train.py          | 428 +++++++++++-------
 applications/Colossal-LLaMA/version.txt       |   2 +-
 13 files changed, 396 insertions(+), 548 deletions(-)
 create mode 100644 applications/Colossal-LLaMA/colossal_llama/dataset/dummy_dataset.py
 delete mode 100644 applications/Colossal-LLaMA/colossal_llama/utils/flash_attention_patch.py
 create mode 100644 applications/Colossal-LLaMA/colossal_llama/utils/utils.py
 rename applications/Colossal-LLaMA/{ => dataset}/prepare_pretrain_dataset.py (100%)
 rename applications/Colossal-LLaMA/{ => dataset}/prepare_sft_dataset.py (100%)
 rename applications/Colossal-LLaMA/{ => inference}/inference_example.py (100%)
 rename applications/Colossal-LLaMA/{ => inference}/stream_chat_example.py (100%)
 create mode 100644 applications/Colossal-LLaMA/setup.py

diff --git a/applications/Colossal-LLaMA/README.md b/applications/Colossal-LLaMA/README.md
index 5997008e8729..e62b14390787 100644
--- a/applications/Colossal-LLaMA/README.md
+++ b/applications/Colossal-LLaMA/README.md
@@ -30,7 +30,7 @@ Colossal-LLaMA
   - [Install](#install)
     - [0. Pre-requisite](#0-pre-requisite)
     - [1. Install required packages](#1-install-required-packages)
-    - [2. Install `xentropy`, `layer_norm` and `rotary`](#2-install-xentropy-layer_norm-and-rotary)
+    - [2. Install Apex](#2-install-apex)
   - [How to run](#how-to-run)
     - [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
     - [2. Init Model Preparation](#2-init-model-preparation)
@@ -297,17 +297,13 @@ Here is details about CLI arguments:
 #### 1. Install required packages
 ```
 cd Colossal-LLaMA
-pip install -r requirements.txt
+pip install -e .
 ```
-#### 2. Install `xentropy`, `layer_norm` and `rotary`
+
+#### 2. Install Apex
 ```bash
-git clone git@github.com:Dao-AILab/flash-attention.git
-# At the root folder
-cd csrc/xentropy && pip install .
-# At the root folder
-cd csrc/layer_norm && pip install .
-# At the root folder
-cd csrc/rotary && pip install .
+git clone git@github.com:NVIDIA/apex.git
+# Install from source.
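+# Editor's sketch (not part of the original patch): a typical from-source Apex build,
+# following the NVIDIA/apex README; adjust the flags to your pip and CUDA setup.
+cd apex
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
+    --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./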
 ```
 
 ### How to run
@@ -427,25 +423,33 @@ Make sure master node can access all nodes (including itself) by ssh without pas
 Here is details about CLI arguments:
 * Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format.
 * Dataset path: `--dataset`. Path to the pre-tokenized dataset.
-* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`,`zero2_cpu` and `3d` are supported.For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
+* Booster plugin: `--plugin`. `ddp`, `gemini`, `gemini_auto`, `zero2`, `zero2_cpu` and `3d` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
 * Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. Saved checkpoint contains the states for `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states to support multi-stage training.
 * Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
 * Checkpoint directory: `--save_dir`. The directory path to save checkpoint and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`.
 * Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs.
 * Configuration file: `--config_file`. The path to save the configuration file.
 * Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1.
-* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1.
+* Batch size: `--batch_size`. Batch size per GPU. The default value is 1. For PP, it refers to the number of samples per step.
 * Learning rate: `--lr`. The default value is 3e-4.
 * Max length: `--max_length`. Max context length. The default value is 4096.
 * Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
 * Gradient clipping: `--gradient_clipping`. The default value is 1.0.
-* Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
-* Warmup steps: `-s`, `--warmup_steps`. The default value is calculated by 0.025 warmup ratio.
+* Weight decay: `--weight_decay`. The default value is 0.1.
+* Warmup steps: `--warmup_steps`. The default value is calculated from a 0.025 warmup ratio.
 * Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
 * Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and related packages. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
 * Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze non-embedding parameters. It can be helpful to align embeddings after extending vocabulary size.
-* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
-* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
+* Tensor parallelism size: `--tp`. TP size for 3d parallelism. The default value is 1. Used for 3d plugin.
+* Pipeline parallelism size: `--pp`. PP size for 3d parallelism. The default value is 1. Used for 3d plugin.
+* Sequence parallelism size: `--sp`. SP size for 3d parallelism.
The default value is 1. Used for 3d plugin. +* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1. Used for 3d plugin. +* Sequence parallelism mode: `--sp_mode`. SP mode, used for 3d plugin. Choose from "split_gather", "ring", "all_to_all". +* Switch for sequence parallelism: `--enable_sequence_parallelism`. Whether to enable SP, used for 3d plugin. +* Zero CPU offload: `--zero_cpu_offload`. Whether to use offloading, used for 3d plugin. +* Micro batch size: `--microbatch_size`. Batch size for each process in PP, used for 3d plugin. +* Number of dummy sample: `--num_samples`. Number of samples for benchmarking. +* Benchmark switch: `--benchmark`. Benchmark performance using random dataset. ##### 4.2 Arguments for Supervised Fine-tuning We add support for gradient accumulation and NEFTuning for supervised fine-tuning and thus there are two more arguments apart from the arguments listed in [4.1 Arguments for Pretraining](#41-arguments-for-pretraining). diff --git a/applications/Colossal-LLaMA/colossal_llama/dataset/dummy_dataset.py b/applications/Colossal-LLaMA/colossal_llama/dataset/dummy_dataset.py new file mode 100644 index 000000000000..3175159fcd37 --- /dev/null +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/dummy_dataset.py @@ -0,0 +1,24 @@ +import torch +from torch.utils.data import Dataset + +from colossalai.accelerator import get_accelerator + + +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } diff --git a/applications/Colossal-LLaMA/colossal_llama/utils/flash_attention_patch.py b/applications/Colossal-LLaMA/colossal_llama/utils/flash_attention_patch.py deleted file mode 100644 index 6c048c3b18cf..000000000000 --- a/applications/Colossal-LLaMA/colossal_llama/utils/flash_attention_patch.py +++ /dev/null @@ -1,352 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import math -from types import MethodType -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaForCausalLM, - LlamaModel, - LlamaRMSNorm, - apply_rotary_pos_emb, - repeat_kv, -) - -from colossalai.accelerator import get_accelerator -from colossalai.logging import get_dist_logger - -logger = get_dist_logger() - -if get_accelerator().name == "cuda": - from flash_attn.bert_padding import pad_input, unpad_input - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func - from flash_attn.ops.rms_norm import rms_norm - - def _prepare_decoder_attention_mask( - self: LlamaModel, - attention_mask: torch.BoolTensor, - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - ) -> Optional[torch.Tensor]: - """ - Decoder attetion mask - """ - if past_key_values_length > 0 and attention_mask is not None: - attention_mask = torch.cat( - tensors=( - torch.full( - 
size=(input_shape[0], past_key_values_length), - fill_value=True, - dtype=attention_mask.dtype, - device=attention_mask.device, - ), - attention_mask, - ), - dim=-1, - ) # (bsz, past_key_values_length + q_len) - if attention_mask is not None and torch.all(attention_mask): - return None # Faster - return attention_mask - - def attention_forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. - """ - if output_attentions: - logger.warning( - "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " - "return `None` instead." - ) - - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - q_slicing, kv_slicing = ( - dim // self.config.pretraining_tp - for dim in ( - self.num_heads * self.head_dim, - self.num_key_value_heads * self.head_dim, - ) - ) # `Tuple[int, int]` - q_slices, k_slices, v_slices = ( - proj.weight.split(slicing, dim=0) - for proj, slicing in ( - (self.q_proj, q_slicing), - (self.k_proj, kv_slicing), - (self.v_proj, kv_slicing), - ) - ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] - q, k, v = ( - torch.cat( - [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], - dim=-1, - ) - for slices in (q_slices, k_slices, v_slices) - ) - # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: - # (bsz, q_len, num_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim) - else: - q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) - # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: - # (bsz, q_len, num_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim) - - # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); - # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); - # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) - q, k, v = ( - states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) - for states, num_heads in ( - (q, self.num_heads), - (k, self.num_key_value_heads), - (v, self.num_key_value_heads), - ) - ) - kv_len = k.shape[-2] # initially, `kv_len` == `q_len` - past_kv_len = 0 - if past_key_value is not None: - # if `past_key_value` is not None, `kv_len` > `q_len`. 
- past_kv_len = past_key_value[0].shape[-2] - kv_len += past_kv_len - - # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) - cos, sin = self.rotary_emb(v, seq_len=kv_len) - # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) - q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) - if past_key_value is not None: - # reuse k, v, self_attention - k = torch.cat([past_key_value[0], k], dim=2) - v = torch.cat([past_key_value[1], v], dim=2) - - past_key_value = (k, v) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) - # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) - v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) - # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) - - key_padding_mask = attention_mask - # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) - q, k, v = (states.transpose(1, 2) for states in (q, k, v)) - - if past_kv_len > 0: - q = torch.cat( - tensors=( - torch.full( - size=(bsz, past_kv_len, self.num_heads, self.head_dim), - fill_value=0.0, - dtype=q.dtype, - device=q.device, - ), - q, - ), - dim=1, - ) # (bsz, past_kv_len + q_len, num_heads, head_dim) - - if key_padding_mask is None: - # (bsz, past_kv_len + q_len, num_heads, head_dim) - output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) - output = rearrange( - output, pattern="... h d -> ... (h d)" - ) # (bsz, past_kv_len + q_len, num_heads * head_dim) - else: - q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) - kv, _, cu_kv_lens, max_kv_len = unpad_input( - hidden_states=torch.stack(tensors=(k, v), dim=2), - attention_mask=key_padding_mask, - ) - output_unpad = flash_attn_varlen_kvpacked_func( - q=q, - kv=kv, - cu_seqlens_q=cu_q_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_q_len, - max_seqlen_k=max_kv_len, - dropout_p=0.0, - softmax_scale=None, - causal=True, - ) - output = pad_input( - hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), - indices=indices, - batch=bsz, - seqlen=past_kv_len + q_len, - ) # (bsz, past_kv_len + q_len, num_heads * head_dim) - - if past_kv_len > 0: - # Strip off the zero query outputs. - output = output[:, past_kv_len:, ...] 
# (bsz, q_len, num_heads * head_dim) - output = self.o_proj(output) # (bsz, q_len, hidden_size) - return output, None, past_key_value - - def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Formard function for RMS Norm - """ - return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) - - def replace_with_flash_attention(model: LlamaForCausalLM) -> None: - for name, module in model.named_modules(): - if isinstance(module, LlamaAttention): - module.forward = MethodType(attention_forward, module) - if isinstance(module, LlamaModel): - module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) - if isinstance(module, LlamaRMSNorm): - module.forward = MethodType(rms_norm_forward, module) - -elif get_accelerator().name == "npu": - import torch_npu - - class NPULlamaAttention(LlamaAttention): - use_flash: bool = True - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.setup() - - def setup(self): - self._softmax_scale = 1 / math.sqrt(self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if not self.use_flash: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / 
math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - else: - attn_output, *_ = torch_npu.npu_fusion_attention( - query_states, - key_states, - value_states, - self.num_heads, - "BNSD", - atten_mask=attention_mask.bool(), - scale=self._softmax_scale, - padding_mask=None, - pre_tockens=65535, - next_tockens=0, - keep_prob=1.0, - inner_precise=0, - ) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum( - [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)] - ) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - class NPURMSNorm(LlamaRMSNorm): - def forward(self, hidden_states): - return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] - - def replace_with_flash_attention(model: LlamaForCausalLM) -> None: - for name, module in model.named_modules(): - if isinstance(module, LlamaAttention): - module.__class__ = NPULlamaAttention - module.setup() - if isinstance(module, LlamaRMSNorm): - module.__class__ = NPURMSNorm diff --git a/applications/Colossal-LLaMA/colossal_llama/utils/utils.py b/applications/Colossal-LLaMA/colossal_llama/utils/utils.py new file mode 100644 index 000000000000..f24ab72c47c9 --- /dev/null +++ b/applications/Colossal-LLaMA/colossal_llama/utils/utils.py @@ -0,0 +1,36 @@ +""" +Utils for Colossal-LLaMA +""" + +import torch +import torch.distributed as dist + +from colossalai.booster import Plugin + + +def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: + if plugin is not None: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group) + tensor.div_(plugin.dp_size) + else: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def get_model_numel(model: torch.nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" diff --git a/applications/Colossal-LLaMA/prepare_pretrain_dataset.py b/applications/Colossal-LLaMA/dataset/prepare_pretrain_dataset.py similarity index 100% 
rename from applications/Colossal-LLaMA/prepare_pretrain_dataset.py rename to applications/Colossal-LLaMA/dataset/prepare_pretrain_dataset.py diff --git a/applications/Colossal-LLaMA/prepare_sft_dataset.py b/applications/Colossal-LLaMA/dataset/prepare_sft_dataset.py similarity index 100% rename from applications/Colossal-LLaMA/prepare_sft_dataset.py rename to applications/Colossal-LLaMA/dataset/prepare_sft_dataset.py diff --git a/applications/Colossal-LLaMA/inference_example.py b/applications/Colossal-LLaMA/inference/inference_example.py similarity index 100% rename from applications/Colossal-LLaMA/inference_example.py rename to applications/Colossal-LLaMA/inference/inference_example.py diff --git a/applications/Colossal-LLaMA/stream_chat_example.py b/applications/Colossal-LLaMA/inference/stream_chat_example.py similarity index 100% rename from applications/Colossal-LLaMA/stream_chat_example.py rename to applications/Colossal-LLaMA/inference/stream_chat_example.py diff --git a/applications/Colossal-LLaMA/requirements.txt b/applications/Colossal-LLaMA/requirements.txt index 809a942ac398..5b62926f616d 100644 --- a/applications/Colossal-LLaMA/requirements.txt +++ b/applications/Colossal-LLaMA/requirements.txt @@ -1,15 +1,15 @@ torch==2.1.2 huggingface-hub packaging==24.0 -colossalai==0.3.6 +colossalai>=0.4.0 autoflake==2.2.1 black==23.9.1 -transformers==4.34.1 +transformers>=4.39.3 tensorboard==2.14.0 six==1.16.0 datasets ninja==1.11.1 -flash-attn>=2.0.0,<=2.0.5 +flash-attn tqdm sentencepiece==0.1.99 protobuf<=3.20.0 diff --git a/applications/Colossal-LLaMA/setup.py b/applications/Colossal-LLaMA/setup.py new file mode 100644 index 000000000000..c9ba31698218 --- /dev/null +++ b/applications/Colossal-LLaMA/setup.py @@ -0,0 +1,37 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open("README.md", encoding="utf-8") as f: + return f.read() + + +def fetch_version(): + with open("version.txt", "r") as f: + return f.read().strip() + + +setup( + name="colossal_llama", + version=fetch_version(), + packages=find_packages(exclude=("*.egg-info",)), + description="Continual Pre-training and SFT for LLaMA", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.7", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], +) diff --git a/applications/Colossal-LLaMA/train.example.sh b/applications/Colossal-LLaMA/train.example.sh index 6a1c887bf6cc..b795e8bcf810 100644 --- a/applications/Colossal-LLaMA/train.example.sh +++ b/applications/Colossal-LLaMA/train.example.sh @@ -1,13 +1,20 @@ #!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} -# 
NCCL IB environment variables -export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1 -export NCCL_IB_DISABLE=0 -export NCCL_SOCKET_IFNAME=eth0 -export NCCL_IB_GID_INDEX=3 -export NCCL_IB_TIMEOUT=23 -export NCCL_IB_RETRY_CNT=7 -export OMP_NUM_THREADS=8 +set_n_least_used_CUDA_VISIBLE_DEVICES 8 PROJECT_NAME="" PARENT_SAVE_DIR="" diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index e74aad33c3e3..112a1e0dc223 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -11,24 +11,24 @@ from contextlib import nullcontext import torch -import torch.distributed as dist +from colossal_llama.dataset.dummy_dataset import RandomDataset from colossal_llama.dataset.loader import ( DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset, ) from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama.utils.flash_attention_patch import replace_with_flash_attention from colossal_llama.utils.froze import freeze_non_embeds_parameters from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune +from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from transformers import AutoTokenizer, LlamaForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer import colossalai from colossalai.accelerator import get_accelerator from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR @@ -36,109 +36,7 @@ from colossalai.utils import get_current_device -def get_model_numel(model: torch.nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) - tensor = tensor.data - tensor.div_(dist.get_world_size()) - return tensor - - -def main() -> None: - # ============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument( - "--pretrained", - type=str, - default=None, - help="Address of the pre-trained modeling", - ) - parser.add_argument("--dataset", nargs="+", default=[]) - parser.add_argument( - "--plugin", - type=str, - default="gemini", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], - help="Choose which plugin to use", - ) - parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") - parser.add_argument("--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory") - parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") - parser.add_argument("--config_file", type=str, default="config_file", help="Config file") - parser.add_argument("--num_epochs", type=int, 
default=1, help="Number of training epochs") - parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps") - parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("--max_length", type=int, default=8192, help="Model max length") - parser.add_argument( - "--mixed_precision", - type=str, - default="fp16", - choices=["fp16", "bf16"], - help="Mixed precision", - ) - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") - parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") - parser.add_argument( - "--use_grad_checkpoint", - action="store_true", - default=False, - help="Use gradient checkpointing", - ) - parser.add_argument( - "--use_flash_attn", - action="store_true", - default=False, - help="Use flash-attention", - ) - parser.add_argument( - "--use_neft", - action="store_true", - default=False, - help="Use NEFTune", - ) - parser.add_argument( - "--freeze_non_embeds_params", - action="store_true", - default=False, - help="Freeze non embeddings parameters", - ) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--zero", type=int, default=1) - parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos") - parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length") - parser.add_argument( - "--skip_save_each_epoch", - action="store_true", - default=False, - help="skip saving the model checkpoint after each epoch is completed.", - ) - args = parser.parse_args() - - with open(args.config_file, "w") as f: - json.dump(args.__dict__, f, indent=4) - +def train(args) -> None: # ============================== # Initialize Distributed Training # ============================== @@ -147,21 +45,27 @@ def main() -> None: coordinator = DistCoordinator() # ============================== - # Initialize Tensorboard + # Initialize Tensorboard and Save Config # ============================== if coordinator.is_master(): os.makedirs(args.tensorboard_dir, exist_ok=True) writer = SummaryWriter(args.tensorboard_dir) + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + # ============================== # Initialize Booster # ============================== - if args.plugin == "gemini": + if args.plugin == "ddp": + plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False) + elif args.plugin == "gemini": plugin = GeminiPlugin( precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip, enable_gradient_accumulation=(args.accumulation_steps > 1), + enable_flash_attention=args.use_flash_attn, ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -170,6 +74,7 @@ def main() -> None: initial_scale=2**16, max_norm=args.grad_clip, enable_gradient_accumulation=(args.accumulation_steps > 1), + enable_flash_attention=args.use_flash_attn, ) elif args.plugin == "zero2": plugin = LowLevelZeroPlugin( @@ -189,10 +94,17 @@ def main() -> None: elif args.plugin == "3d": plugin = HybridParallelPlugin( tp_size=args.tp, - pp_size=1, - zero_stage=args.zero, + pp_size=args.pp, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + zero_stage=args.zero_stage, + enable_flash_attention=args.use_flash_attn, + 
enable_sequence_parallelism=args.enable_sequence_parallelism, + cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, + parallel_output=False, max_norm=args.grad_clip, precision=args.mixed_precision, + microbatch_size=args.microbatch_size, ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -210,24 +122,38 @@ def main() -> None: tokenizer.add_bos_token = False tokenizer.add_eos_token = False - coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") - coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}") - coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}") + coordinator.print_on_master( + f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}" + ) - coordinator.print_on_master(f"Load dataset: {args.dataset}") + if args.benchmark: + coordinator.print_on_master(f"Run benchmark with {args.num_samples} random samples.") + dataset = RandomDataset( + num_samples=args.num_samples, max_length=args.max_length, vocab_size=tokenizer.vocab_size + ) + dataloader = plugin.prepare_dataloader( + dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + seed=42, + distributed_sampler_cls=StatefulDistributedSampler, + ) + else: + coordinator.print_on_master(f"Load dataset: {args.dataset}") + dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") + data_collator = DataCollatorForSupervisedDataset( + tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode + ) + dataloader = plugin.prepare_dataloader( + dataset=dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) - dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") - data_collator = DataCollatorForSupervisedDataset( - tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode - ) - dataloader = plugin.prepare_dataloader( - dataset=dataset, - batch_size=args.micro_batch_size, - shuffle=True, - drop_last=True, - collate_fn=data_collator, - distributed_sampler_cls=StatefulDistributedSampler, - ) coordinator.print_on_master( f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" ) @@ -241,7 +167,19 @@ def main() -> None: else nullcontext() ) with init_ctx: - model = LlamaForCausalLM.from_pretrained(args.pretrained) + if args.use_flash_attn: + model = AutoModelForCausalLM.from_pretrained( + args.pretrained, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + trust_remote_code=True, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + args.pretrained, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + trust_remote_code=True, + ) # Freeze part of parameters. 
if args.freeze_non_embeds_params: freeze_non_embeds_parameters(model=model) @@ -251,9 +189,6 @@ def main() -> None: if args.use_grad_checkpoint: model.gradient_checkpointing_enable() coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") - if args.use_flash_attn: - replace_with_flash_attention(model=model) - coordinator.print_on_master(msg="Flash-attention enabled successfully") model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") @@ -342,43 +277,98 @@ def main() -> None: for epoch in range(start_epoch, args.num_epochs): dataloader.sampler.set_epoch(epoch=epoch) - pbar = tqdm( - desc=f"Epoch {epoch}", - disable=not coordinator.is_master(), - total=num_steps_per_epoch, - initial=start_step // args.accumulation_steps, - ) - total_loss = torch.tensor(0.0, device=get_current_device()) - for step, batch in enumerate(dataloader, start=start_step): - batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} - - batch_output = model(**batch) - - loss = batch_output.loss / args.accumulation_steps - total_loss.add_(loss.data) - - booster.backward(loss=loss, optimizer=optimizer) - - if (step + 1) % args.accumulation_steps == 0: + if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1: + data_iter = iter(dataloader) + step_bar = tqdm( + range(len(dataloader)), + desc="Step", + disable=not (coordinator._local_rank == coordinator._world_size - 1), + ) + for step in step_bar: + outputs = booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if booster.plugin.stage_manager.is_last_stage(): + global_loss = all_reduce_mean(loss, plugin) + if coordinator._local_rank == coordinator._world_size - 1: + step_bar.set_postfix({"train/loss": global_loss.item()}) optimizer.step() - lr_scheduler.step() optimizer.zero_grad() - all_reduce_mean(tensor=total_loss) - pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) - if coordinator.is_master(): - global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps - writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) - writer.add_scalar( - tag="Learning Rate", - scalar_value=lr_scheduler.get_last_lr()[0], - global_step=global_step, + # Save modeling. + save_model_condition = args.save_interval > 0 and (step + 1) % args.save_interval == 0 + + if not args.skip_save_each_epoch: + save_model_condition = save_model_condition or (step + 1) == len(dataloader) + + if save_model_condition and not args.benchmark: + coordinator.print_on_master("\nStart saving model checkpoint with running states") + + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune before saving model.") + deactivate_neftune(model, handle) + + accelerator.empty_cache() + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.batch_size, + coordinator=coordinator, + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" ) - total_loss.fill_(0.0) - pbar.update() - # Save modeling. 
+ if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + else: + pbar = tqdm( + desc=f"Epoch {epoch}", + disable=not coordinator.is_master(), + total=num_steps_per_epoch, + initial=start_step // args.accumulation_steps, + ) + total_loss = torch.tensor(0.0, device=get_current_device()) + for step, batch in enumerate(dataloader, start=start_step): + batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} + + batch_output = model(**batch) + + loss = batch_output.loss / args.accumulation_steps + total_loss.add_(loss.data) + + booster.backward(loss=loss, optimizer=optimizer) + + if (step + 1) % args.accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + all_reduce_mean(tensor=total_loss) + pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) + if coordinator.is_master(): + global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps + writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) + writer.add_scalar( + tag="Learning Rate", + scalar_value=lr_scheduler.get_last_lr()[0], + global_step=global_step, + ) + total_loss.fill_(0.0) + pbar.update() + # Save modeling. save_model_condition = ( args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0 ) @@ -386,7 +376,7 @@ def main() -> None: if not args.skip_save_each_epoch: save_model_condition = save_model_condition or (step + 1) == len(dataloader) - if save_model_condition: + if save_model_condition and not args.benchmark: coordinator.print_on_master("\nStart saving model checkpoint with running states") if args.use_neft: @@ -402,7 +392,7 @@ def main() -> None: lr_scheduler=lr_scheduler, epoch=epoch, step=step + 1, - batch_size=args.micro_batch_size, + batch_size=args.batch_size, coordinator=coordinator, ) coordinator.print_on_master( @@ -426,12 +416,114 @@ def main() -> None: deactivate_neftune(model, handle) # Final save. - coordinator.print_on_master("Start saving final model checkpoint") - booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") + if not args.benchmark: + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + # Basic training information. 
+ parser.add_argument( + "--pretrained", + type=str, + default=None, + help="Address of the pre-trained model", + ) + parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint for continuous training.") + parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"], + help="Choose which plugin to use", + ) + parser.add_argument("--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory") + parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") + parser.add_argument("--config_file", type=str, default="config_file", help="Config file") + # Training parameters + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps") + parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("--max_length", type=int, default=8192, help="Model max length") + parser.add_argument( + "--mixed_precision", + type=str, + default="fp16", + choices=["fp16", "bf16"], + help="Mixed precision", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument( + "--use_grad_checkpoint", + action="store_true", + default=False, + help="Use gradient checkpointing", + ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + default=False, + help="Use flash-attention", + ) + parser.add_argument( + "--use_neft", + action="store_true", + default=False, + help="Use NEFTune", + ) + parser.add_argument( + "--freeze_non_embeds_params", + action="store_true", + default=False, + help="Freeze non embeddings parameters", + ) + parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos") + parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length") + parser.add_argument( + "--skip_save_each_epoch", + action="store_true", + default=False, + help="Skip saving the model checkpoint after each epoch is completed.", + ) + + # Additional arguments for 3d plugin. + parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.") + parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.") + parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.") + parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2]) + parser.add_argument( + "--sp_mode", + type=str, + default="split_gather", + choices=["split_gather", "ring", "all_to_all"], + help="SP mode, used for 3d plugin.", + ) + parser.add_argument( + "--enable_sequence_parallelism", + default=False, + action="store_true", + help="Whether to enable SP, used for 3d plugin.", + ) + parser.add_argument( + "--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin." 
+    )
+    parser.add_argument(
+        "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
+    )
+
+    # Additional arguments for benchmark.
+    parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
+    parser.add_argument(
+        "--benchmark", action="store_true", default=False, help="Benchmark performance using random dataset."
+    )
+    args = parser.parse_args()
+    train(args)
diff --git a/applications/Colossal-LLaMA/version.txt b/applications/Colossal-LLaMA/version.txt
index 3eefcb9dd5b3..9084fa2f716a 100644
--- a/applications/Colossal-LLaMA/version.txt
+++ b/applications/Colossal-LLaMA/version.txt
@@ -1 +1 @@
-1.0.0
+1.1.0
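
For reference, a minimal launch sketch exercising the refactored CLI above; the launcher invocation, checkpoint path, GPU count, and parallel sizes are illustrative placeholders and are not taken from this patch:

```bash
# Hypothetical single-node run with the 3d plugin and the benchmark mode added in this PR.
# /path/to/llama-hf-checkpoint and the parallel sizes below are placeholders.
colossalai run --nproc_per_node 8 train.py \
    --pretrained /path/to/llama-hf-checkpoint \
    --plugin 3d \
    --tp 2 \
    --pp 2 \
    --zero_stage 1 \
    --microbatch_size 1 \
    --batch_size 2 \
    --max_length 4096 \
    --mixed_precision bf16 \
    --use_flash_attn \
    --use_grad_checkpoint \
    --benchmark \
    --num_samples 500
```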