-
Notifications
You must be signed in to change notification settings - Fork 199
/
Copy path run_parallel.py
executable file
·46 lines (37 loc) · 1.72 KB
/
run_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from stepvideo.diffusion.video_pipeline import StepVideoPipeline
import torch.distributed as dist
import torch
from stepvideo.config import parse_args
from stepvideo.parallel import initialize_parall_group, get_parallel_group
from stepvideo.utils import setup_seed
from xfuser.model_executor.models.customized.step_video_t2v.tp_applicator import TensorParallelApplicator
from xfuser.core.distributed.parallel_state import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
def _build_pipeline(args, device):
    """Load the StepVideo pipeline and move its transformer onto ``device``.

    The full pipeline is loaded from ``args.model_dir`` in bfloat16 on the CPU.
    When ``args.tensor_parallel_degree > 1`` the transformer is sharded with
    ``TensorParallelApplicator`` using this rank's tensor-parallel world size
    and rank. Only the transformer is then moved to ``device``; the other
    pipeline components stay where the initial ``.to(..., device="cpu")``
    placed them.

    Args:
        args: Parsed command-line namespace from ``parse_args()``.
        device: Target ``torch.device`` for the transformer.

    Returns:
        The configured ``StepVideoPipeline`` with its API endpoints set.
    """
    pipeline = StepVideoPipeline.from_pretrained(args.model_dir).to(
        dtype=torch.bfloat16, device="cpu"
    )
    if args.tensor_parallel_degree > 1:
        # Shard while still on the CPU, before the transformer is moved to the
        # GPU (presumably so each rank only transfers its own shard; TODO confirm).
        tp_applicator = TensorParallelApplicator(
            get_tensor_model_parallel_world_size(),
            get_tensor_model_parallel_rank(),
        )
        tp_applicator.apply_to_model(pipeline.transformer)
    pipeline.transformer = pipeline.transformer.to(device)
    # Point the pipeline at the VAE and caption endpoints given on the CLI.
    pipeline.setup_api(
        vae_url=args.vae_url,
        caption_url=args.caption_url,
    )
    return pipeline


def main():
    """Run one distributed StepVideo text-to-video generation from CLI args.

    Initializes the ring/Ulysses/tensor parallel groups, seeds RNGs, builds
    the pipeline on this rank's CUDA device and generates a video for
    ``args.prompt``. The torch.distributed process group is destroyed on exit,
    including when any step after group initialization raises.
    """
    args = parse_args()
    initialize_parall_group(
        ring_degree=args.ring_degree,
        ulysses_degree=args.ulysses_degree,
        tensor_parallel_degree=args.tensor_parallel_degree,
    )
    try:
        local_rank = get_parallel_group().local_rank
        device = torch.device(f"cuda:{local_rank}")
        setup_seed(args.seed)

        pipeline = _build_pipeline(args, device)

        prompt = args.prompt
        pipeline(
            prompt=prompt,
            num_frames=args.num_frames,
            height=args.height,
            width=args.width,
            num_inference_steps=args.infer_steps,
            guidance_scale=args.cfg_scale,
            time_shift=args.time_shift,
            pos_magic=args.pos_magic,
            neg_magic=args.neg_magic,
            # NOTE(review): the first 50 prompt characters are used verbatim as
            # the output name; if the pipeline builds a file path from it,
            # characters such as "/" may break saving -- confirm.
            output_file_name=prompt[:50],
        )
    finally:
        # Tear down the process group even when generation fails, so no rank
        # exits with the group still alive. The is_initialized() guard keeps a
        # teardown error from masking the original exception.
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()