Contribute Ascend NPU Compatibility
sunyiran committed May 9, 2024
1 parent c6cc021 commit 69121cf
Showing 10 changed files with 306 additions and 52 deletions.
39 changes: 39 additions & 0 deletions README.md
@@ -192,6 +192,45 @@ Run the following command to start the docker container in interactive mode.
docker run -ti --gpus all -v {MOUNT_DIR}:/data opensora
```

## Installation (Ascend NPU)

1. Create a virtual environment
```bash
# create a virtual env
conda create -n opensora python=3.10
# activate virtual environment
conda activate opensora

# install torch
pip install torch==2.1.0 torchvision==0.16.0
```
2. Set up the NPU environment

Please refer to the Ascend community document [Pytorch框架训练环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes) (PyTorch training environment setup) to prepare the Ascend NPU environment. A quick verification snippet follows these steps.
```bash
# activate cann env
source ${cann_install_path}/ascend-toolkit/set_env.sh

# install this project
git clone https://github.com/hpcaitech/Open-Sora
cd Open-Sora
pip install -v .
```
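
After these steps, a quick sanity check can confirm that PyTorch sees the NPU. This is a minimal sketch rather than part of the official instructions; it assumes `torch_npu` was installed and the CANN environment was sourced as above.

```python
# Minimal sanity check; assumes the CANN env has been sourced and torch_npu is installed.
import torch
import torch_npu  # noqa: F401  # registers the "npu" device type with PyTorch

print(torch.npu.is_available())  # expected: True on a working Ascend setup
print(torch.npu.device_count())  # number of visible NPU devices
```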

### Use Docker

Run the following command to build a docker image from Dockerfile provided.

```bash
docker build -t opensora ./docker
```

Run the following command to start the docker container in interactive mode.

```bash
docker run -ti --gpus all -v {MOUNT_DIR}:/data opensora
```

## Model Weights

### Open-Sora 1.1 Model Weights
31 changes: 31 additions & 0 deletions docs/zh_CN/README.md
@@ -77,6 +77,7 @@
## Table of Contents

* [Installation](#安装)
* [Ascend NPU Installation](#安装昇腾npu环境)
* [Model Weights](#模型权重)
* [Inference](#推理)
* [Data Processing](#数据处理)
@@ -130,6 +131,36 @@ docker run -ti --gpus all -v {MOUNT_DIR}:/data opensora

After installation, we recommend reading [structure.md](structure.md) to learn about the project structure and how to use the configuration files.


## Ascend NPU Installation

1. Create a virtual environment
```bash
# create a virtual env
conda create -n opensora python=3.10
# activate virtual environment
conda activate opensora

# install torch
pip install torch==2.1.0 torchvision==0.16.0
```

2. Prepare the Ascend NPU environment


Please follow the Ascend community document [Pytorch框架训练环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes) (PyTorch training environment setup) to set up the Ascend environment.

```bash
# activate cann env
source ${cann_install_path}/ascend-toolkit/set_env.sh

# install this project
git clone https://github.com/hpcaitech/Open-Sora
cd Open-Sora
pip install -v .
```


## Model Weights

| Resolution | Data | Iterations | Batch Size | GPU Days (H800) | URL |
32 changes: 26 additions & 6 deletions opensora/acceleration/shardformer/modeling/t5.py
@@ -1,5 +1,21 @@
import torch
import torch.nn as nn
from opensora.utils.device_utils import is_npu_available
if is_npu_available():
import torch_npu


class NpuRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Initialize NPU RMSNorm normalization layer
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps

def forward(self, x):
return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]


class T5LayerNorm(nn.Module):
@@ -28,12 +44,16 @@ def forward(self, hidden_states):

@staticmethod
def from_native_module(module, *args, **kwargs):
assert module.__class__.__name__ == "FusedRMSNorm", (
"Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm."
"Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
)

layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
if is_npu_available():
normalized_shape = module.weight.shape[0]
layer_norm = NpuRMSNorm(normalized_shape, eps=module.variance_epsilon)
else:
assert module.__class__.__name__ == "FusedRMSNorm", (
"Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm."
"Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
)

layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
layer_norm.weight.data.copy_(module.weight.data)
layer_norm = layer_norm.to(module.weight.device)
return layer_norm
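
For reference, the fused `torch_npu.npu_rms_norm` kernel wrapped by `NpuRMSNorm` performs the same T5-style RMS normalization as the eager `T5LayerNorm`. The snippet below is a device-independent sketch of that computation for illustration only; it is not part of the commit.

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # T5-style RMSNorm: no mean subtraction; variance taken over the last dim in fp32
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * x.to(weight.dtype)

x = torch.randn(2, 4, 8)
w = torch.ones(8)
print(rms_norm_reference(x, w).shape)  # torch.Size([2, 4, 8])
```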
161 changes: 123 additions & 38 deletions opensora/models/layers/blocks.py
@@ -19,12 +19,17 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import xformers.ops
from einops import rearrange
from timm.models.vision_transformer import Mlp

from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.utils.device_utils import is_npu_available
if not is_npu_available():
import xformers.ops
else:
import torch_npu


approx_gelu = lambda: nn.GELU(approximate="tanh")

@@ -178,19 +183,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
q, k = self.q_norm(q), self.k_norm(k)

if enable_flash_attn:
from flash_attn import flash_attn_func

# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
x = flash_attn_func(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
if is_npu_available() and q.dtype in [torch.float16, torch.bfloat16]:
x = torch_npu.npu_fusion_attention(
q, k, v, self.num_heads, input_layout="BNSD",
pse=None,
scale=self.scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1. - self.attn_drop.p if self.training else 1.,
sync=False,
inner_precise=0,
)[0]
x = x.transpose(1, 2)
else:
from flash_attn import flash_attn_func

# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
x = flash_attn_func(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
else:
dtype = q.dtype
q = q * self.scale
@@ -270,15 +288,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.enable_flash_attn:
from flash_attn import flash_attn_func

x = flash_attn_func(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
if is_npu_available() and q.dtype in [torch.float16, torch.bfloat16]:
x = torch_npu.npu_fusion_attention(
q, k, v, q.shape[-2], input_layout="BSND",
pse=None,
scale=self.scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.-self.attn_drop.p if self.training else 1.,
sync=False,
inner_precise=0,
)[0]
else:
from flash_attn import flash_attn_func

x = flash_attn_func(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
else:
dtype = q.dtype
q = q * self.scale
@@ -323,14 +353,42 @@ def forward(self, x, cond, mask=None):
# query/value: img tokens; key: condition; mask: if padding tokens
B, N, C = x.shape

q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
if is_npu_available() and x.dtype in [torch.float16, torch.bfloat16]:
q = self.q_linear(x).view(-1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(-1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(1)

actual_seq_qlen = []
actual_seq_kvlen = []
if mask is not None:
ans = 0
for _ in range(B):
ans += N
actual_seq_qlen.append(ans)
ans = 0
for m in mask:
ans += m
actual_seq_kvlen.append(ans)
x = torch_npu.npu_fusion_attention(
q, k, v, self.num_heads, input_layout="TND",
pse=None,
scale=1.0 / math.sqrt(self.head_dim),
pre_tockens=65536,
next_tockens=65536,
actual_seq_qlen=tuple(actual_seq_qlen),
actual_seq_kvlen=tuple(actual_seq_kvlen),
keep_prob=1. - self.attn_drop.p,
sparse_mode=0,
)[0]
else:
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)

attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)

x = x.view(B, -1, C)
x = self.proj(x)
@@ -372,15 +430,42 @@ def forward(self, x, cond, mask=None):
k = split_forward_gather_backward(k, get_sequence_parallel_group(), dim=2, grad_scale="down")
v = split_forward_gather_backward(v, get_sequence_parallel_group(), dim=2, grad_scale="down")

q = q.view(1, -1, self.num_heads // sp_size, self.head_dim)
k = k.view(1, -1, self.num_heads // sp_size, self.head_dim)
v = v.view(1, -1, self.num_heads // sp_size, self.head_dim)

# compute attention
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
if is_npu_available() and q.dtype in [torch.float16, torch.bfloat16]:
q = q.view(-1, self.num_heads // sp_size, self.head_dim)
k = k.view(-1, self.num_heads // sp_size, self.head_dim)
v = v.view(-1, self.num_heads // sp_size, self.head_dim)

actual_seq_qlen = []
actual_seq_kvlen = []
if mask is not None:
ans = 0
for _ in range(B):
ans += N
actual_seq_qlen.append(ans)
ans = 0
for m in mask:
ans += m
actual_seq_kvlen.append(ans)
x = torch_npu.npu_fusion_attention(
q, k, v, q.shape[-2], input_layout="TND",
pse=None,
scale=1.0 / math.sqrt(self.head_dim),
pre_tockens=65536,
next_tockens=65536,
actual_seq_qlen=tuple(actual_seq_qlen),
actual_seq_kvlen=tuple(actual_seq_kvlen),
keep_prob=1. - self.attn_drop.p,
sparse_mode=0,
)[0]
else:
q = q.view(1, -1, self.num_heads // sp_size, self.head_dim)
k = k.view(1, -1, self.num_heads // sp_size, self.head_dim)
v = v.view(1, -1, self.num_heads // sp_size, self.head_dim)
attn_bias = None
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)

# apply all to all to gather back attention heads and scatter sequence
x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
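
The `actual_seq_qlen` / `actual_seq_kvlen` loops above accumulate per-sample sequence lengths into the cumulative end offsets that `npu_fusion_attention` expects for the packed `TND` layout. An equivalent sketch with hypothetical example values (illustration only):

```python
from itertools import accumulate

B, N = 2, 3      # hypothetical batch size and image tokens per sample
mask = [2, 5]    # hypothetical valid condition-token counts per sample

# Cumulative end offsets of each sequence in the packed (T, N, D) tensors,
# matching what the loops in forward() build.
actual_seq_qlen = tuple(accumulate([N] * B))  # (3, 6)
actual_seq_kvlen = tuple(accumulate(mask))    # (2, 7)
print(actual_seq_qlen, actual_seq_kvlen)
```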
17 changes: 17 additions & 0 deletions opensora/utils/device_utils.py
@@ -0,0 +1,17 @@
import torch
import importlib


def is_npu_available():
"Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
return False

import torch_npu

try:
# Will raise a RuntimeError if no NPU is found
_ = torch.npu.device_count()
return torch.npu.is_available()
except RuntimeError:
return False
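
A minimal sketch of how callers are expected to use this helper to guard NPU-only imports and pick a device, mirroring the pattern used in `blocks.py` and `t5.py` (the device selection itself is illustrative, not part of the commit):

```python
import torch
from opensora.utils.device_utils import is_npu_available

if is_npu_available():
    import torch_npu  # noqa: F401  # provides NPU kernels such as npu_fusion_attention
    device = torch.device("npu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(8, 8, device=device)
print(x.device)
```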
20 changes: 20 additions & 0 deletions requirements_npu.txt
@@ -0,0 +1,20 @@
colossalai
accelerate
diffusers
ftfy
gdown
mmengine
pandas
pre-commit
pyarrow
av
tensorboard
timm
tqdm
transformers
wandb
rotary_embedding_torch
pandarallel
scipy
decorator
attrs
4 changes: 4 additions & 0 deletions scripts/inference-long.py
@@ -15,6 +15,10 @@
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
from opensora.utils.device_utils import is_npu_available
if is_npu_available():
from torch_npu.contrib import transfer_to_npu
torch.npu.config.allow_internal_format = False


def collect_references_batch(reference_paths, vae, image_size):
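
`torch_npu.contrib.transfer_to_npu` monkey-patches the `torch.cuda` namespace so that existing CUDA-targeted code runs on Ascend devices without modification. A minimal sketch of the intended effect (assumes an Ascend NPU and `torch_npu` are present; illustration only):

```python
import torch
from torch_npu.contrib import transfer_to_npu  # noqa: F401  # remaps torch.cuda.* calls to the NPU

x = torch.randn(4, 4).cuda()  # with the import above, this tensor lands on the NPU
print(x.device)               # e.g. npu:0
```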
