From 65c61d0d5dd12a1cb7649d00caa65c2ce6600cc9 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Fri, 21 Jun 2024 13:09:10 +0800 Subject: [PATCH 01/12] [doc] add v1.2 blog (#517) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 10ef22bf..44ea1e5b 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi ## 📰 News -- **[2024.06.17]** 🔥 We released **Open-Sora 1.2**, which includes **3D-VAE**, **rectified flow**, and **score condition**. The video quality is greatly improved. [[checkpoints]](#open-sora-10-model-weights) [[report]](/docs/report_03.md) +- **[2024.06.17]** 🔥 We released **Open-Sora 1.2**, which includes **3D-VAE**, **rectified flow**, and **score condition**. The video quality is greatly improved. [[checkpoints]](#open-sora-10-model-weights) [[report]](/docs/report_03.md) [[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use) - **[2024.04.25]** 🤗 We released the [Gradio demo for Open-Sora](https://huggingface.co/spaces/hpcai-tech/open-sora) on Hugging Face Spaces. - **[2024.04.25]** We released **Open-Sora 1.1**, which supports **2s~15s, 144p to 720p, any aspect ratio** text-to-image, **text-to-video, image-to-video, video-to-video, infinite time** generation. In addition, a full video processing pipeline is released. [[checkpoints]]() [[report]](/docs/report_02.md) - **[2024.03.18]** We released **Open-Sora 1.0**, a fully open-source project for video generation. From 91ccddc01caeeb7a70304c11fc4c42a4237fc9cf Mon Sep 17 00:00:00 2001 From: HangXu Date: Fri, 21 Jun 2024 11:15:39 +0300 Subject: [PATCH 02/12] Force fp16 input to fp32 to avoid nan output in timestep_transform --- opensora/schedulers/rf/rectified_flow.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py index 58d7b486..8acaff5d 100644 --- a/opensora/schedulers/rf/rectified_flow.py +++ b/opensora/schedulers/rf/rectified_flow.py @@ -15,6 +15,11 @@ def timestep_transform( scale=1.0, num_timesteps=1, ): + # Force fp16 input to fp32 to avoid nan output + for key in ["height", "width", "num_frames"]: + if model_kwargs[key].dtype == torch.float16: + model_kwargs[key] = model_kwargs[key].float() + t = t / num_timesteps resolution = model_kwargs["height"] * model_kwargs["width"] ratio_space = (resolution / base_resolution).sqrt() From e8ad745a9ccacbefaa8ca2462d76164d14389458 Mon Sep 17 00:00:00 2001 From: liuwenran <448073814@qq.com> Date: Fri, 21 Jun 2024 16:27:30 +0800 Subject: [PATCH 03/12] fix broken links in cn report v1 --- docs/zh_CN/report_v1.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/zh_CN/report_v1.md b/docs/zh_CN/report_v1.md index bf12131a..feedda37 100644 --- a/docs/zh_CN/report_v1.md +++ b/docs/zh_CN/report_v1.md @@ -11,11 +11,11 @@ OpenAI的Sora在生成一分钟高质量视频方面非常出色。然而,它 如图中所示,在STDiT(ST代表时空)中,我们在每个空间注意力之后立即插入一个时间注意力。这类似于Latte论文中的变种3。然而,我们并没有控制这些变体的相似数量的参数。虽然Latte的论文声称他们的变体比变种3更好,但我们在16x256x256视频上的实验表明,相同数量的迭代次数下,性能排名为:DiT(完整)> STDiT(顺序)> STDiT(并行)≈ Latte。因此,我们出于效率考虑选择了STDiT(顺序)。[这里](/docs/acceleration.md#efficient-stdit)提供了速度基准测试。 -![Architecture Comparison](https://i0.imgs.ovh/2024/03/15/eLk9D.png) +![Architecture Comparison](/assets/readme/report_arch_comp.png) 
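The sequential spatial-temporal design described above can be sketched in a few lines of PyTorch. The sketch below is a minimal illustration under stated assumptions (a plain `nn.MultiheadAttention`, `einops.rearrange`, and invented module names), not the repository's actual STDiT implementation:

```python
import torch
import torch.nn as nn
from einops import rearrange


class SequentialSTBlock(nn.Module):
    """Spatial attention followed immediately by temporal attention (STDiT, sequential)."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.norm_s = nn.LayerNorm(dim)
        self.attn_s = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_t = nn.LayerNorm(dim)
        self.attn_t = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # Zero-init the projection after temporal attention so the inserted
        # temporal branch starts as an identity mapping (see the next paragraph).
        self.proj_t = nn.Linear(dim, dim)
        nn.init.zeros_(self.proj_t.weight)
        nn.init.zeros_(self.proj_t.bias)

    def forward(self, x: torch.Tensor, t: int, s: int) -> torch.Tensor:
        # x: [B, T*S, C] -- t frames, s spatial tokens per frame
        h = self.norm_s(rearrange(x, "b (t s) c -> (b t) s c", t=t, s=s))
        h = self.attn_s(h, h, h, need_weights=False)[0]  # attend over space
        x = x + rearrange(h, "(b t) s c -> b (t s) c", t=t)
        h = self.norm_t(rearrange(x, "b (t s) c -> (b s) t c", t=t, s=s))
        h = self.attn_t(h, h, h, need_weights=False)[0]  # attend over time
        x = x + rearrange(self.proj_t(h), "(b s) t c -> b (t s) c", s=s)
        return x


# Example: 16 frames x 64 spatial tokens, hidden dim 64
out = SequentialSTBlock(dim=64, num_heads=4)(torch.randn(2, 16 * 64, 64), t=16, s=64)
```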
为了专注于视频生成,我们希望基于一个强大的图像生成模型来训练我们的模型。PixArt-α是一个经过高效训练的高质量图像生成模型,具有T5条件化的DiT结构。我们使用[PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha)初始化我们的模型,并将插入的时间注意力的投影层初始化为零。这种初始化在开始时保留了模型的图像生成能力,而Latte的架构则不能。插入的注意力将参数数量从5.8亿增加到7.24亿。 -![Architecture](https://i0.imgs.ovh/2024/03/16/erC1d.png) +![Architecture](/assets/readme/report_arch.jpg) 借鉴PixArt-α和Stable Video Diffusion的成功,我们还采用了渐进式训练策略:在366K预训练数据集上进行16x256x256的训练,然后在20K数据集上进行16x256x256、16x512x512和64x512x512的训练。通过扩展位置嵌入,这一策略极大地降低了计算成本。 @@ -26,7 +26,7 @@ OpenAI的Sora在生成一分钟高质量视频方面非常出色。然而,它 我们发现数据的数量和质量对生成视频的质量有很大的影响,甚至比模型架构和训练策略的影响还要大。目前,我们只从[HD-VG-130M](https://github.com/daooshee/HD-VG-130M)准备了第一批分割(366K个视频片段)。这些视频的质量参差不齐,而且字幕也不够准确。因此,我们进一步从提供免费许可视频的[Pexels](https://www.pexels.com/)收集了20k相对高质量的视频。我们使用LLaVA,一个图像字幕模型,通过三个帧和一个设计好的提示来标记视频。有了设计好的提示,LLaVA能够生成高质量的字幕。 -![Caption](https://i0.imgs.ovh/2024/03/16/eXdvC.png) +![Caption](/assets/readme/report_caption.png) 由于我们更加注重数据质量,我们准备收集更多数据,并在下一版本中构建一个视频预处理流程。 @@ -38,12 +38,12 @@ OpenAI的Sora在生成一分钟高质量视频方面非常出色。然而,它 16x256x256 预训练损失曲线 -![16x256x256 Pretraining Loss Curve](https://i0.imgs.ovh/2024/03/16/erXQj.png) +![16x256x256 Pretraining Loss Curve](/assets/readme/report_loss_curve_1.png) 16x256x256 高质量训练损失曲线 -![16x256x256 HQ Training Loss Curve](https://i0.imgs.ovh/2024/03/16/ernXv.png) +![16x256x256 HQ Training Loss Curve](/assets/readme/report_loss_curve_2.png) 16x512x512 高质量训练损失曲线 -![16x512x512 HQ Training Loss Curve](https://i0.imgs.ovh/2024/03/16/erHBe.png) +![16x512x512 HQ Training Loss Curve](/assets/readme/report_loss_curve_3.png) From 019d3de55ab4928610cd1191b5b7e94a26c7df25 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Fri, 21 Jun 2024 18:23:30 +0000 Subject: [PATCH 04/12] [feat] reduce memory leakage in dataloader and pyav --- configs/opensora-v1-2/train/demo_360p.py | 58 ++++ .../train/{stage3_480p.py => demo_480p.py} | 17 +- opensora/datasets/dataloader.py | 4 + opensora/datasets/datasets.py | 6 +- opensora/datasets/read_video.py | 252 +++++++++++------- scripts/train.py | 2 +- tools/architecture/__init__.py | 0 tools/architecture/net2net.py | 73 ----- 8 files changed, 226 insertions(+), 186 deletions(-) create mode 100644 configs/opensora-v1-2/train/demo_360p.py rename configs/opensora-v1-2/train/{stage3_480p.py => demo_480p.py} (76%) delete mode 100644 tools/architecture/__init__.py delete mode 100644 tools/architecture/net2net.py diff --git a/configs/opensora-v1-2/train/demo_360p.py b/configs/opensora-v1-2/train/demo_360p.py new file mode 100644 index 00000000..e27bd3cd --- /dev/null +++ b/configs/opensora-v1-2/train/demo_360p.py @@ -0,0 +1,58 @@ +# Dataset settings +dataset = dict( + type="VariableVideoTextDataset", + transform_name="resize_crop", +) + +# webvid +bucket_config = {"360p": {102: (1.0, 5)}} +grad_checkpoint = True + +# Acceleration settings +num_workers = 8 +num_bucket_build_workers = 16 +dtype = "bf16" +plugin = "zero2" + +# Model settings +model = dict( + type="STDiT3-XL/2", + from_pretrained=None, + qk_norm=True, + enable_flash_attn=True, + enable_layernorm_kernel=True, + freeze_y_embedder=True, +) +vae = dict( + type="OpenSoraVAE_V1_2", + from_pretrained="hpcai-tech/OpenSora-VAE-v1.2", + micro_frame_size=17, + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=300, + shardformer=True, +) +scheduler = dict( + type="rflow", + use_timestep_transform=True, + sample_method="logit-normal", +) + +# Log settings +seed = 42 +outputs = "outputs" +wandb = False +epochs = 1000 +log_every = 10 +ckpt_every = 200 
+ +# optimization settings +load = None +grad_clip = 1.0 +lr = 1e-4 +ema_decay = 0.99 +adam_eps = 1e-15 +warmup_steps = 1000 diff --git a/configs/opensora-v1-2/train/stage3_480p.py b/configs/opensora-v1-2/train/demo_480p.py similarity index 76% rename from configs/opensora-v1-2/train/stage3_480p.py rename to configs/opensora-v1-2/train/demo_480p.py index b4b9ffdb..08121c7b 100644 --- a/configs/opensora-v1-2/train/stage3_480p.py +++ b/configs/opensora-v1-2/train/demo_480p.py @@ -9,7 +9,7 @@ grad_checkpoint = True # Acceleration settings -num_workers = 0 +num_workers = 8 num_bucket_build_workers = 16 dtype = "bf16" plugin = "zero2" @@ -41,21 +41,6 @@ sample_method="logit-normal", ) -# Mask settings -# 25% -mask_ratios = { - "random": 0.01, - "intepolate": 0.002, - "quarter_random": 0.002, - "quarter_head": 0.002, - "quarter_tail": 0.002, - "quarter_head_tail": 0.002, - "image_random": 0.0, - "image_head": 0.22, - "image_tail": 0.005, - "image_head_tail": 0.005, -} - # Log settings seed = 42 outputs = "outputs" diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py index 8bcaed95..15058ac8 100644 --- a/opensora/datasets/dataloader.py +++ b/opensora/datasets/dataloader.py @@ -34,6 +34,7 @@ def prepare_dataloader( process_group: Optional[ProcessGroup] = None, bucket_config=None, num_bucket_build_workers=1, + prefetch_factor=None, **kwargs, ): _kwargs = kwargs.copy() @@ -57,6 +58,7 @@ def prepare_dataloader( pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn_default, + prefetch_factor=prefetch_factor, **_kwargs, ), batch_sampler, @@ -79,6 +81,7 @@ def prepare_dataloader( pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn_default, + prefetch_factor=prefetch_factor, **_kwargs, ), sampler, @@ -98,6 +101,7 @@ def prepare_dataloader( pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn_batch, + prefetch_factor=prefetch_factor, **_kwargs, ), sampler, diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py index 34a5dcf6..fcf070a9 100644 --- a/opensora/datasets/datasets.py +++ b/opensora/datasets/datasets.py @@ -151,9 +151,11 @@ def getitem(self, index): # Sampling video frames video = temporal_random_crop(vframes, num_frames, self.frame_interval) + video = video.clone() + del vframes video_fps = video_fps // self.frame_interval - + # transform transform = get_transforms_video(self.transform_name, (height, width)) video = transform(video) # T C H W @@ -169,7 +171,7 @@ def getitem(self, index): # repeat video = image.unsqueeze(0) - # TCHW -> CTHW + # # TCHW -> CTHW video = video.permute(1, 0, 2, 3) ret = { "video": video, diff --git a/opensora/datasets/read_video.py b/opensora/datasets/read_video.py index f988c306..ce88f593 100644 --- a/opensora/datasets/read_video.py +++ b/opensora/datasets/read_video.py @@ -1,20 +1,19 @@ import gc import math import os +import re +import warnings from fractions import Fraction -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import av import cv2 import numpy as np import torch -from torchvision.io.video import ( - _align_audio_frames, - _check_av_available, - _log_api_usage_once, - _read_from_stream, - _video_opt, -) +from torchvision import get_video_backend +from torchvision.io.video import _check_av_available + +MAX_NUM_FRAMES = 2500 def read_video_av( @@ -27,6 +26,13 @@ def read_video_av( """ Reads a video from a file, returning both the video frames and the audio frames + This method is modified from 
torchvision.io.video.read_video, with the following changes: + + 1. will not extract audio frames and return empty for aframes + 2. remove checks and only support pyav + 3. add container.close() and gc.collect() to avoid thread leakage + 4. try our best to avoid memory leak + Args: filename (str): path to the video file start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): @@ -42,99 +48,162 @@ def read_video_av( aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int) """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(read_video) - + # format output_format = output_format.upper() if output_format not in ("THWC", "TCHW"): raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.") - - from torchvision import get_video_backend - + # file existence if not os.path.exists(filename): raise RuntimeError(f"File not found: {filename}") - - if get_video_backend() != "pyav": - vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit) - else: - _check_av_available() - - if end_pts is None: - end_pts = float("inf") - - if end_pts < start_pts: - raise ValueError( - f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}" - ) - - info = {} - video_frames = [] - audio_frames = [] - audio_timebase = _video_opt.default_timebase - - container = av.open(filename, metadata_errors="ignore") - try: - if container.streams.audio: - audio_timebase = container.streams.audio[0].time_base - if container.streams.video: - video_frames = _read_from_stream( - container, - start_pts, - end_pts, - pts_unit, - container.streams.video[0], - {"video": 0}, - ) - video_fps = container.streams.video[0].average_rate - # guard against potentially corrupted files - if video_fps is not None: - info["video_fps"] = float(video_fps) - - if container.streams.audio: - audio_frames = _read_from_stream( - container, - start_pts, - end_pts, - pts_unit, - container.streams.audio[0], - {"audio": 0}, - ) - info["audio_fps"] = container.streams.audio[0].rate - except av.AVError: - # TODO raise a warning? 
- pass - finally: - container.close() - del container - # NOTE: manually garbage collect to close pyav threads - gc.collect() - - vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames] - aframes_list = [frame.to_ndarray() for frame in audio_frames] - - if vframes_list: - vframes = torch.as_tensor(np.stack(vframes_list)) - else: - vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) - - if aframes_list: - aframes = np.concatenate(aframes_list, 1) - aframes = torch.as_tensor(aframes) - if pts_unit == "sec": - start_pts = int(math.floor(start_pts * (1 / audio_timebase))) - if end_pts != float("inf"): - end_pts = int(math.ceil(end_pts * (1 / audio_timebase))) - aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts) - else: - aframes = torch.empty((1, 0), dtype=torch.float32) - + # backend check + assert get_video_backend() == "pyav", "pyav backend is required for read_video_av" + _check_av_available() + # end_pts check + if end_pts is None: + end_pts = float("inf") + if end_pts < start_pts: + raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") + + # == get video info == + info = {} + # TODO: creating an container leads to memory leak (1G for 8 workers 1 GPU) + container = av.open(filename, metadata_errors="ignore") + # fps + video_fps = container.streams.video[0].average_rate + # guard against potentially corrupted files + if video_fps is not None: + info["video_fps"] = float(video_fps) + iter_video = container.decode(**{"video": 0}) + frame = next(iter_video).to_rgb().to_ndarray() + height, width = frame.shape[:2] + total_frames = container.streams.video[0].frames + if total_frames == 0: + total_frames = MAX_NUM_FRAMES + warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback") + container.close() + del container + + # HACK: must create before iterating stream + # use np.zeros will not actually allocate memory + # use np.ones will lead to a little memory leak + video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8) + + # == read == + # TODO: The reading has memory leak (4G for 8 workers 1 GPU) + container = av.open(filename, metadata_errors="ignore") + assert container.streams.video is not None + video_frames = _read_from_stream( + video_frames, + container, + start_pts, + end_pts, + pts_unit, + container.streams.video[0], + {"video": 0}, + ) + + vframes = torch.from_numpy(video_frames).clone() + del video_frames if output_format == "TCHW": # [T,H,W,C] --> [T,C,H,W] vframes = vframes.permute(0, 3, 1, 2) + aframes = torch.empty((1, 0), dtype=torch.float32) return vframes, aframes, info +def _read_from_stream( + video_frames, + container: "av.container.Container", + start_offset: float, + end_offset: float, + pts_unit: str, + stream: "av.stream.Stream", + stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]], +) -> List["av.frame.Frame"]: + + if pts_unit == "sec": + # TODO: we should change all of this from ground up to simply take + # sec and convert to MS in C++ + start_offset = int(math.floor(start_offset * (1 / stream.time_base))) + if end_offset != float("inf"): + end_offset = int(math.ceil(end_offset * (1 / stream.time_base))) + else: + warnings.warn("The pts_unit 'pts' gives wrong results. 
Please use pts_unit 'sec'.") + + should_buffer = True + max_buffer_size = 5 + if stream.type == "video": + # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt) + # so need to buffer some extra frames to sort everything + # properly + extradata = stream.codec_context.extradata + # overly complicated way of finding if `divx_packed` is set, following + # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263 + if extradata and b"DivX" in extradata: + # can't use regex directly because of some weird characters sometimes... + pos = extradata.find(b"DivX") + d = extradata[pos:] + o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d) + if o is None: + o = re.search(rb"DivX(\d+)b(\d+)(\w)", d) + if o is not None: + should_buffer = o.group(3) == b"p" + seek_offset = start_offset + # some files don't seek to the right location, so better be safe here + seek_offset = max(seek_offset - 1, 0) + if should_buffer: + # FIXME this is kind of a hack, but we will jump to the previous keyframe + # so this will be safe + seek_offset = max(seek_offset - max_buffer_size, 0) + try: + # TODO check if stream needs to always be the video stream here or not + container.seek(seek_offset, any_frame=False, backward=True, stream=stream) + except av.AVError: + # TODO add some warnings in this case + # print("Corrupted file?", container.name) + return [] + + # == main == + buffer_count = 0 + frames_pts = [] + cnt = 0 + for _idx, frame in enumerate(container.decode(**stream_name)): + frames_pts.append(frame.pts) + video_frames[cnt] = frame.to_rgb().to_ndarray() + cnt += 1 + if cnt >= len(video_frames): + break + if frame.pts >= end_offset: + if should_buffer and buffer_count < max_buffer_size: + buffer_count += 1 + continue + break + + # garbage collection for thread leakage + container.close() + del container + # NOTE: manually garbage collect to close pyav threads + gc.collect() + + # ensure that the results are sorted wrt the pts + # NOTE: here we assert frames_pts is sorted + start_ptr = 0 + end_ptr = cnt + while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset: + start_ptr += 1 + while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset: + end_ptr -= 1 + if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]: + # if there is no frame that exactly matches the pts of start_offset + # add the last frame smaller than start_offset, to guarantee that + # we will have all the necessary data. 
This is most useful for audio + if start_ptr > 0: + start_ptr -= 1 + result = video_frames[start_ptr:end_ptr].copy() + return result + + def read_video_cv2(video_path): cap = cv2.VideoCapture(video_path) @@ -181,8 +250,3 @@ def read_video(video_path, backend="av"): raise ValueError return vframes, vinfo - - -if __name__ == "__main__": - vframes, vinfo = read_video("./data/colors/9.mp4", backend="cv2") - x = 0 diff --git a/scripts/train.py b/scripts/train.py index 2ebdd412..574fc0b3 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -98,6 +98,7 @@ def main(): drop_last=True, pin_memory=True, process_group=get_data_parallel_group(), + prefetch_factor=cfg.get("prefetch_factor", None), ) dataloader, sampler = prepare_dataloader( bucket_config=cfg.get("bucket_config", None), @@ -247,7 +248,6 @@ def main(): with Timer("move data") as move_data_t: x = batch.pop("video").to(device, dtype) # [B, C, T, H, W] y = batch.pop("text") - timer_list.append(move_data_t) # == visual and text encoding == with Timer("encode") as encode_t: diff --git a/tools/architecture/__init__.py b/tools/architecture/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tools/architecture/net2net.py b/tools/architecture/net2net.py deleted file mode 100644 index d5d7eb34..00000000 --- a/tools/architecture/net2net.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -Implementation of Net2Net (http://arxiv.org/abs/1511.05641) -Numpy modules for Net2Net -- Net2Wider -- Net2Deeper - -Written by Kyunghyun Paeng - -""" - - -def net2net(teach_param, stu_param): - # teach param with shape (a, b) - # stu param with shape (c, d) - # net to net (a, b) -> (c, d) where c >= a and d >= b - teach_param_shape = teach_param.shape - stu_param_shape = stu_param.shape - - if len(stu_param_shape) > 2: - teach_param = teach_param.reshape(teach_param_shape[0], -1) - stu_param = stu_param.reshape(stu_param_shape[0], -1) - - assert len(stu_param.shape) == 1 or len(stu_param.shape) == 2, "teach_param and stu_param must be 2-dim array" - assert len(teach_param_shape) == len(stu_param_shape), "teach_param and stu_param must have same dimension" - - if len(teach_param_shape) == 1: - stu_param[: teach_param_shape[0]] = teach_param - elif len(teach_param_shape) == 2: - stu_param[: teach_param_shape[0], : teach_param_shape[1]] = teach_param - else: - breakpoint() - - if stu_param.shape != stu_param_shape: - stu_param = stu_param.reshape(stu_param_shape) - - return stu_param - - -if __name__ == "__main__": - """Net2Net Class Test""" - - import torch - - from opensora.models.pixart import PixArt_1B_2 - - model = PixArt_1B_2(no_temporal_pos_emb=True, space_scale=4, enable_flash_attn=True, enable_layernorm_kernel=True) - print("load model done") - - ckpt = torch.load("/home/zhouyukun/projs/opensora/pretrained_models/PixArt-Sigma-XL-2-2K-MS.pth") - print("load ckpt done") - - ckpt = ckpt["state_dict"] - ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) - - missing_keys = [] - for name, module in model.named_parameters(): - if name in ckpt: - teach_param = ckpt[name].data - stu_param = module.data - stu_param = net2net(teach_param, stu_param) - - module.data = stu_param - - print("processing layer: ", name, "shape: ", module.size()) - - else: - # print("Missing key: ", name) - missing_keys.append(name) - - print(missing_keys) - - breakpoint() - torch.save({"state_dict": model.state_dict()}, "PixArt-1B-2.pth") From 49d5edd99313e2304a8f75adc253f39be1249772 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Fri, 21 Jun 2024 
19:03:30 +0000 Subject: [PATCH 05/12] [fix] support stdit1 training --- opensora/models/stdit/stdit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opensora/models/stdit/stdit.py b/opensora/models/stdit/stdit.py index 0bb3d5dd..428ad269 100644 --- a/opensora/models/stdit/stdit.py +++ b/opensora/models/stdit/stdit.py @@ -255,7 +255,7 @@ def __init__( else: self.sp_rank = None - def forward(self, x, timestep, y, mask=None, x_mask=None): + def forward(self, x, timestep, y, mask=None, x_mask=None, **kwargs): """ Forward pass of STDiT. Args: From 22c4707a90aa3af31e6da52f5c53c74c81f19f20 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Fri, 21 Jun 2024 19:21:00 +0000 Subject: [PATCH 06/12] [fix] time list --- scripts/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/train.py b/scripts/train.py index 574fc0b3..98bb2d70 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -248,6 +248,7 @@ def main(): with Timer("move data") as move_data_t: x = batch.pop("video").to(device, dtype) # [B, C, T, H, W] y = batch.pop("text") + timer_list.append(move_data_t) # == visual and text encoding == with Timer("encode") as encode_t: From 1b79ec3b4db9b1abee8d176f9d68b3ad3d58c7fc Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Fri, 21 Jun 2024 19:22:02 +0000 Subject: [PATCH 07/12] minor fix --- opensora/datasets/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py index fcf070a9..8b5fdd6a 100644 --- a/opensora/datasets/datasets.py +++ b/opensora/datasets/datasets.py @@ -171,7 +171,7 @@ def getitem(self, index): # repeat video = image.unsqueeze(0) - # # TCHW -> CTHW + # TCHW -> CTHW video = video.permute(1, 0, 2, 3) ret = { "video": video, From d74ef76b91482653fd285e6865eaf382554a04dd Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Sat, 22 Jun 2024 11:36:13 +0000 Subject: [PATCH 08/12] [docs] update tutorial --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 44ea1e5b..ab7bf35b 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Open-Sora not only democratizes access to advanced video generation techniques, streamlined and user-friendly platform that simplifies the complexities of video generation. With Open-Sora, our goal is to foster innovation, creativity, and inclusivity within the field of content creation. -[[中文文档]](/docs/zh_CN/README.md) [[潞晨云部署视频教程]](https://www.bilibili.com/video/BV141421R7Ag) +[[中文文档](/docs/zh_CN/README.md)] [[潞晨云](https://cloud.luchentech.com/)|[OpenSora镜像](https://cloud.luchentech.com/doc/docs/image/open-sora/)|[视频教程](https://www.bilibili.com/video/BV1ow4m1e7PX/?vd_source=c6b752764cd36ff0e535a768e35d98d2)] ## 📰 News @@ -38,8 +38,7 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi ## 🎥 Latest Demo -🔥 You can experience Open-Sora on our [🤗 Gradio application on Hugging Face](https://huggingface.co/spaces/hpcai-tech/open-sora). More samples are available in our [Gallery](https://hpcaitech.github.io/Open-Sora/). - +🔥 You can experience Open-Sora on our [🤗 Gradio application on Hugging Face](https://huggingface.co/spaces/hpcai-tech/open-sora). More samples and corresponding prompts are available in our [Gallery](https://hpcaitech.github.io/Open-Sora/). 
| **4s 720×1280** | **4s 720×1280** | **4s 720×1280** | | ---------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | @@ -47,7 +46,6 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/644bf938-96ce-44aa-b797-b3c0b513d64c) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/272d88ac-4b4a-484d-a665-8d07431671d0) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/ebbac621-c34e-4bb4-9543-1c34f8989764) | | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/a1e3a1a3-4abd-45f5-8df2-6cced69da4ca) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/d6ce9c13-28e1-4dff-9644-cc01f5f11926) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/561978f8-f1b0-4f4d-ae7b-45bec9001b4a) | -
OpenSora 1.1 Demo From e92672bdae1de7a819e4210cab7b9e1ecbd8906f Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Sat, 22 Jun 2024 13:26:55 +0000 Subject: [PATCH 09/12] handle av error --- opensora/datasets/read_video.py | 59 ++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/opensora/datasets/read_video.py b/opensora/datasets/read_video.py index ce88f593..6b5da346 100644 --- a/opensora/datasets/read_video.py +++ b/opensora/datasets/read_video.py @@ -89,18 +89,22 @@ def read_video_av( video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8) # == read == - # TODO: The reading has memory leak (4G for 8 workers 1 GPU) - container = av.open(filename, metadata_errors="ignore") - assert container.streams.video is not None - video_frames = _read_from_stream( - video_frames, - container, - start_pts, - end_pts, - pts_unit, - container.streams.video[0], - {"video": 0}, - ) + try: + # TODO: The reading has memory leak (4G for 8 workers 1 GPU) + container = av.open(filename, metadata_errors="ignore") + assert container.streams.video is not None + video_frames = _read_from_stream( + video_frames, + container, + start_pts, + end_pts, + pts_unit, + container.streams.video[0], + {"video": 0}, + filename=filename, + ) + except av.AVError as e: + print(f"[Warning] Error while reading video {filename}: {e}") vframes = torch.from_numpy(video_frames).clone() del video_frames @@ -120,6 +124,7 @@ def _read_from_stream( pts_unit: str, stream: "av.stream.Stream", stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]], + filename: Optional[str] = None, ) -> List["av.frame.Frame"]: if pts_unit == "sec": @@ -159,26 +164,28 @@ def _read_from_stream( try: # TODO check if stream needs to always be the video stream here or not container.seek(seek_offset, any_frame=False, backward=True, stream=stream) - except av.AVError: - # TODO add some warnings in this case - # print("Corrupted file?", container.name) + except av.AVError as e: + print(f"[Warning] Error while seeking video {filename}: {e}") return [] # == main == buffer_count = 0 frames_pts = [] cnt = 0 - for _idx, frame in enumerate(container.decode(**stream_name)): - frames_pts.append(frame.pts) - video_frames[cnt] = frame.to_rgb().to_ndarray() - cnt += 1 - if cnt >= len(video_frames): - break - if frame.pts >= end_offset: - if should_buffer and buffer_count < max_buffer_size: - buffer_count += 1 - continue - break + try: + for _idx, frame in enumerate(container.decode(**stream_name)): + frames_pts.append(frame.pts) + video_frames[cnt] = frame.to_rgb().to_ndarray() + cnt += 1 + if cnt >= len(video_frames): + break + if frame.pts >= end_offset: + if should_buffer and buffer_count < max_buffer_size: + buffer_count += 1 + continue + break + except av.AVError as e: + print(f"[Warning] Error while reading video {filename}: {e}") # garbage collection for thread leakage container.close() From b3f7df82399b27a2128cdf423e933ccf6680df74 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Sat, 22 Jun 2024 15:54:27 +0000 Subject: [PATCH 10/12] [fix] better support local ckpt --- opensora/models/stdit/stdit3.py | 5 +++-- opensora/models/vae/vae.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/opensora/models/stdit/stdit3.py b/opensora/models/stdit/stdit3.py index 8703b2d1..bd9672db 100644 --- a/opensora/models/stdit/stdit3.py +++ b/opensora/models/stdit/stdit3.py @@ -448,7 +448,7 @@ def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w): @MODELS.register_module("STDiT3-XL/2") def 
STDiT3_XL_2(from_pretrained=None, **kwargs): force_huggingface = kwargs.pop("force_huggingface", False) - if force_huggingface or from_pretrained is not None and not os.path.isdir(from_pretrained): + if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained): model = STDiT3.from_pretrained(from_pretrained, **kwargs) else: config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) @@ -460,7 +460,8 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs): @MODELS.register_module("STDiT3-3B/2") def STDiT3_3B_2(from_pretrained=None, **kwargs): - if from_pretrained is not None and not os.path.isdir(from_pretrained): + force_huggingface = kwargs.pop("force_huggingface", False) + if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained): model = STDiT3.from_pretrained(from_pretrained, **kwargs) else: config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs) diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index bf50ec83..9802b02d 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -277,7 +277,7 @@ def OpenSoraVAE_V1_2( scale=scale, ) - if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)): + if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)): model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs) else: config = VideoAutoencoderPipelineConfig(**kwargs) From 0312a0d85290eeb5c04f122c3688baaa0f8d9c8c Mon Sep 17 00:00:00 2001 From: Jiacheng Yang Date: Mon, 24 Jun 2024 05:07:49 -0400 Subject: [PATCH 11/12] fix SeqParallelMultiHeadCrossAttention for consistent results in distributed mode (#510) --- opensora/models/layers/blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py index 8bc7e720..5e2c13da 100644 --- a/opensora/models/layers/blocks.py +++ b/opensora/models/layers/blocks.py @@ -499,7 +499,7 @@ def forward(self, x, cond, mask=None): # shape: # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM] - q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim) kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down") k, v = kv.unbind(2) From d2d6dd9485c3cb193ab053e14b8735eb90ce990f Mon Sep 17 00:00:00 2001 From: guangxiangyang <1327140521@qq.com> Date: Tue, 25 Jun 2024 09:57:33 +0800 Subject: [PATCH 12/12] Update README.md --- gradio/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradio/README.md b/gradio/README.md index 5dddd470..aee7303a 100644 --- a/gradio/README.md +++ b/gradio/README.md @@ -1,4 +1,4 @@ ---- +gaungxiangyang--- title: Open Sora emoji: 🎥 colorFrom: red
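A closing note on the one-character change in [PATCH 11/12]: the comment just above the changed line already documents the intended `[B, SUB_N, NUM_HEADS, HEAD_DIM]` layout, and the fix makes the query tensor match it. A shape-level sketch follows; the sizes are illustrative assumptions, not values from the repository, and the point is only what the two `view` calls do:

```python
import torch
import torch.nn as nn

B, SUB_N, NUM_HEADS, HEAD_DIM = 2, 8, 4, 16  # assumed example sizes
C = NUM_HEADS * HEAD_DIM
x = torch.randn(B, SUB_N, C)  # each rank's sequence shard under sequence parallelism
q_linear = nn.Linear(C, C)

# Before the fix: the batch axis is collapsed, so tokens from different
# samples in the batch are concatenated into one sequence of length B * SUB_N.
q_old = q_linear(x).view(1, -1, NUM_HEADS, HEAD_DIM)  # [1, B*SUB_N, H, D]

# After the fix: each sample keeps its own row, matching the documented
# [B, SUB_N, NUM_HEADS, HEAD_DIM] layout, so multi-sample batches line up
# with their own conditioning and results stay consistent in distributed mode.
q_new = q_linear(x).view(B, -1, NUM_HEADS, HEAD_DIM)  # [B, SUB_N, H, D]

print(q_old.shape, q_new.shape)
# torch.Size([1, 16, 4, 16]) torch.Size([2, 8, 4, 16])
```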