diff --git a/.gitignore b/.gitignore index 96c0f1a..00ebf5a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ module_articulate/models/*.pth module_spline/models/*.tar module_mrfa/models/*.pth +module_fsrt/models/*.pt checkpoints/* !checkpoints/place_checkpoints_here.txt diff --git a/README.md b/README.md index 307d15f..f9b7b76 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,9 @@ Now supports: 2. [Motion Representations for Articulated Animation](https://github.com/snap-research/articulated-animation) 3. [Thin-Plate Spline Motion Model for Image Animation](https://github.com/yoyo-nb/thin-plate-spline-motion-model) 4. [Learning Motion Refinement for Unsupervised Face Animation](https://github.com/JialeTao/MRFA/) +5. [Facial Scene Representation Transformer for Face Reenactment](https://github.com/andrerochow/fsrt) -More will come soon - -https://github.com/user-attachments/assets/b2948efb-3b44-440b-bff2-dde7b95a9946 +https://github.com/user-attachments/assets/b090061d-8f12-42c4-b046-d8b0e0a69685 ## Workflow: @@ -46,6 +45,12 @@ https://github.com/user-attachments/assets/b2948efb-3b44-440b-bff2-dde7b95a9946 ![Workflow MRFA](workflows/workflow_mrfa.png) +### FSRT + +[FSRT.json](workflows/FSRT.json) + +![Workflow FSRT](workflows/workflow_fsrt.png) + ## Arguments ### FOMM @@ -78,10 +83,21 @@ Doesn't need any ### MRFA -* `model_name`: `celebvhq` or `vox`, which is trained on different datasets +* `model_name`: `vox` or `celebvhq`, which are trained on (presumably) the `vox256` and `celebvhq` datasets respectively. * `use_relative`: Whether to use relative mode or not (absolute mode). Absolute mode is similar to FOMM's `adapt_movement_scale` set to False * `relative_movement`, `relative_jacobian`, `adapt_movement_scale`: Same as FOMM +### FSRT + +This model is the slowest to run; the full Damedane example takes ~6 minutes. + +* `model_name`: `vox256` or `vox256_2Source`; per the bundled configs, the latter is trained with two source images (`num_src: 2`) rather than one. +* `relative`: Use relative or absolute keypoint coordinates +* `adapt_scale`: Adapt movement scale based on the convex hull of the keypoints +* `find_best_frame`: Same as FOMM +* `max_num_pixels`: Number of pixels processed in parallel. Reduce this value if you run out of GPU memory + + ## Installation 1. 
Clone the repo to `ComfyUI/custom_nodes/` @@ -150,9 +166,13 @@ resnet18-5c106cde.pth | **Spline** | `module_articulate/models/vox.pth.tar` | [Thin Plate Spline Motion Model (Pre-trained models)](https://github.com/yoyo-nb/thin-plate-spline-motion-model?tab=readme-ov-file#pre-trained-models) | | **MRFA** (celebvhq) | `module_mrfa/models/celebvhq.pth` | [MRFA (Pre-trained checkpoints)](https://github.com/JialeTao/MRFA/?tab=readme-ov-file#pretrained-models) | | **MRFA** (vox) | `module_mrfa/models/vox.pth` | [MRFA (Pre-trained checkpoints)](https://github.com/JialeTao/MRFA/?tab=readme-ov-file#pretrained-models) | +| **FSRT** (kp_detector) | `module_fsrt/models/kp_detector.pt` | [FSRT (Pretrained Checkpoints)](https://github.com/andrerochow/fsrt?tab=readme-ov-file#pretrained-checkpoints) | +| **FSRT** (vox256) | `module_fsrt/models/vox256.pt` | [FSRT (Pretrained Checkpoints)](https://github.com/andrerochow/fsrt?tab=readme-ov-file#pretrained-checkpoints) | +| **FSRT** (vox256_2Source) | `module_fsrt/models/vox256_2Source.pt` | [FSRT (Pretrained Checkpoints)](https://github.com/andrerochow/fsrt?tab=readme-ov-file#pretrained-checkpoints) | Notes: -- For **Spline**, to use `find_best_frame`, follow above instructions to install `face-alignment` with its models. +- For **Spline** and **FSRT**, to use `find_best_frame`, follow the instructions above to install `face-alignment` with its models. +- For **FSRT**, you must also download `kp_detector.pt` (listed above) ## Credits @@ -195,3 +215,12 @@ year={2023}, url={https://openreview.net/forum?id=m9uHv1Pxq7} } ``` + +``` +@inproceedings{rochow2024fsrt, + title={{FSRT}: Facial Scene Representation Transformer for Face Reenactment from Factorized Appearance, Head-pose, and Facial Expression Features}, + author={Rochow, Andre and Schwarz, Max and Behnke, Sven}, + booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2024} +} +``` diff --git a/__init__.py b/__init__.py index f59a68f..544d4f6 100644 --- a/__init__.py +++ b/__init__.py @@ -6,7 +6,8 @@ FOMM_Seg15Chooser, Articulate_Runner, Spline_Runner, - MRFA_Runner + MRFA_Runner, + FSRT_Runner ) NODE_CLASS_MAPPINGS = { @@ -18,6 +19,7 @@ "Articulate_Runner": Articulate_Runner, "Spline_Runner": Spline_Runner, "MRFA_Runner": MRFA_Runner, + "FSRT_Runner": FSRT_Runner, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -29,6 +31,7 @@ "Articulate_Runner": "Articulate Runner", "Spline_Runner": "Spline Runner", "MRFA_Runner": "MRFA Runner", + "FSRT_Runner": "FSRT Runner", } diff --git a/constants.py b/constants.py index ad76d08..4221e98 100644 --- a/constants.py +++ b/constants.py @@ -102,20 +102,37 @@ ARTICULATE_MODEL_PATH = "module_articulate/models/vox256.pth" ARTICULATE_CFG_PATH = "module_articulate/config/vox256.yaml" -SPLINE_MODES = ['relative', 'standard', 'avd'] -SPLINE_DEFAULT = 'relative' +SPLINE_MODES = ["relative", "standard", "avd"] +SPLINE_DEFAULT = "relative" SPLINE_MODEL_PATH = "module_spline/models/vox.pth.tar" SPLINE_CFG_PATH = "module_spline/config/vox-256.yaml" -MRFA_MODEL_NAMES = ["celebvhq", "vox"] +MRFA_MODEL_NAMES = ["vox", "celebvhq"] MRFA_MODEL_PATHS = { - "celebvhq": "module_mrfa/models/celebvhq.pth", "vox": "module_mrfa/models/vox.pth", + "celebvhq": "module_mrfa/models/celebvhq.pth", } -MRFA_DEFAULT_MODEL = "celebvhq" +MRFA_DEFAULT_MODEL = "vox" MRFA_CFG_PATHS = { "celebvhq": "module_mrfa/configs/celebvhq.yaml", "vox": "module_mrfa/configs/vox1.yaml", -} \ No newline at end of file +} + +FSRT_MODEL_NAMES = [ + "vox256", "vox256_2Source" +] + +FSRT_DEFAULT_MODEL = "vox256" + 
+FSRT_MODEL_PATHS = { + "vox256": "module_fsrt/models/vox256.pt", + "vox256_2Source": "module_fsrt/models/vox256_2Source.pt", +} + +FSRT_KP_PATH = "module_fsrt/models/kp_detector.pt" +FSRT_CFG_PATHS = { + "vox256": "module_fsrt/configs/vox256.yaml", + "vox256_2Source": "module_fsrt/configs/vox256_2Source.yaml", +} diff --git a/inference_fsrt.py b/inference_fsrt.py new file mode 100644 index 0000000..738ab07 --- /dev/null +++ b/inference_fsrt.py @@ -0,0 +1,311 @@ +import numpy as np +import torch +import tqdm +import yaml +from comfy.utils import ProgressBar +from scipy.spatial import ConvexHull + +from .inference_fomm import find_best_frame +from .module_fsrt.checkpoint import Checkpoint +from .module_fsrt.expression_encoder import ExpressionEncoder +from .module_fsrt.keypoint_detector import KPDetector +from .module_fsrt.model import FSRT + + +def fsrt_inference( + source_image, + driving_video: list, + config_path: str, + checkpoint_path: str, + keypoint_path: str, + relative=False, # use relative or absolute keypoint coordinates + adapt_scale=False, # adapt movement scale based on convex hull of keypoints + find_best_frame=False, # Generate from the frame that is the most alligned with source + max_num_pixels=65536, # Number of parallel processed pixels. Reduce this value if you run out of GPU memory! +): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + with open(config_path) as f: + cfg = yaml.full_load(f) + + kp_detector = KPDetector().to(device) + kp_detector.load_state_dict(torch.load(keypoint_path)) + expression_encoder = ExpressionEncoder( + expression_size=cfg["model"]["expression_size"], + in_channels=kp_detector.predictor.out_filters, + ) + + model = FSRT(cfg["model"], expression_encoder=expression_encoder).to(device) + + model.eval() + kp_detector.eval() + + encoder_module = model.encoder + decoder_module = model.decoder + expression_encoder_module = model.expression_encoder + + checkpoint = Checkpoint( + "./", + device=device, + encoder=encoder_module, + decoder=decoder_module, + expression_encoder=expression_encoder_module, + ) + + _ = checkpoint.load(checkpoint_path) + + source_image = source_image.to(device) + if find_best_frame: + predictions = inference_best_frame( + source_image, + driving_video, + model, + kp_detector, + cfg, + device, + max_num_pixels, + relative=relative, + adapt_movement_scale=adapt_scale, + ) + else: + predictions = inference( + source_image, + driving_video, + model, + kp_detector, + cfg, + device, + max_num_pixels, + relative=relative, + adapt_movement_scale=adapt_scale, + ) + + return predictions + + +def inference_best_frame( + source_image, + driving_video, + model, + kp_detector, + cfg, + device, + max_num_pixels, + relative=False, + adapt_movement_scale=False, +): + best_frame_idx = find_best_frame(source_image, driving_video) + + first_half = driving_video[ + :, :, : best_frame_idx + 1 + ] # Include the best frame in the first half + first_half = torch.flip(first_half, dims=[2]) # Reverse the first half + second_half = driving_video[:, :, best_frame_idx + 1 :] + + predictions_first = inference( + source_image, + first_half, + model, + kp_detector, + cfg, + device, + max_num_pixels, + relative, + adapt_movement_scale, + ) + predictions_second = inference( + source_image, + second_half, + model, + kp_detector, + cfg, + device, + max_num_pixels, + relative, + adapt_movement_scale, + ) + + predictions = [] + predictions_first = predictions_first[::-1] # Reverse the first half back + 
predictions.extend(predictions_first) + predictions.extend(predictions_second) + + return predictions + + +def inference( + source_image, + driving_video, + model, + kp_detector, + cfg, + device, + max_num_pixels, + relative=False, + adapt_movement_scale=False, +): + source_image = source_image.permute(0, 2, 3, 1) + + _, y, x = np.meshgrid( + np.zeros(2), + np.arange(source_image.shape[-3]), + np.arange(source_image.shape[-2]), + indexing="ij", + ) + idx_grids = np.stack([x, y], axis=-1).astype(np.float32) + # Normalize + idx_grids[..., 0] = (idx_grids[..., 0] + 0.5 - ((source_image.shape[-3]) / 2.0)) / ( + (source_image.shape[-3]) / 2.0 + ) + idx_grids[..., 1] = (idx_grids[..., 1] + 0.5 - ((source_image.shape[-2]) / 2.0)) / ( + (source_image.shape[-2]) / 2.0 + ) + idx_grids = torch.from_numpy(idx_grids).to(device).unsqueeze(0) + z = None + + with torch.no_grad(): + predictions = [] + source = source_image.permute(0, 3, 1, 2) + # driving = driving_video.permute(0, 4, 1, 2, 3) + driving = driving_video + # print(f"{source.shape=}") + # print(f"{driving_video.shape=}") + kp_source, expression_vector_src = extract_keypoints_and_expression( + source.clone(), model, kp_detector, cfg, src=True + ) + kp_driving_initial, _ = extract_keypoints_and_expression( + driving[:, :, 0].to(device).clone(), model, kp_detector, cfg + ) + + num_frames = driving.shape[2] + pbar = ProgressBar(num_frames) + for frame_idx in tqdm.tqdm(range(num_frames), desc="Generating"): + driving_frame = driving[:, :, frame_idx].to(device) + kp_driving, expression_vector_driv = extract_keypoints_and_expression( + driving_frame.clone(), model, kp_detector, cfg + ) + kp_norm = normalize_kp( + kp_source=kp_source[0], + kp_driving=kp_driving, + kp_driving_initial=kp_driving_initial, + use_relative_movement=relative, + adapt_movement_scale=adapt_movement_scale, + ) + out, z = forward_model( + model, + expression_vector_src, + kp_source, + expression_vector_driv, + kp_norm, + source.unsqueeze(0), + idx_grids, + cfg, + max_num_pixels, + z=z, + ) + + pred = torch.clamp(out[0], 0.0, 1.0) + predictions.append(pred.unsqueeze(0)) + pbar.update_absolute(frame_idx, num_frames) + + return predictions + + +def forward_model( + model, + expression_vector_src, + keypoints_src, + expression_vector_driv, + keypoints_driv, + img_src, + idx_grids, + cfg, + max_num_pixels, + z=None, +): + # render_kwargs = cfg["model"]["decoder_kwargs"] + if len(img_src.shape) < 5: + img_src = img_src.unsqueeze(1) + if len(keypoints_src.shape) < 4: + keypoints_src = keypoints_src.unsqueeze(1) + + if z is None: + z = model.encoder( + img_src, + keypoints_src, + idx_grids[:, :1].repeat(1, img_src.shape[1], 1, 1, 1), + expression_vector=expression_vector_src, + ) + + target_pos = idx_grids[:, 1] + target_kps = keypoints_driv + + _, height, width = target_pos.shape[:3] + target_pos = target_pos.flatten(1, 2) + + target_kps = target_kps.unsqueeze(1).repeat(1, target_pos.shape[1], 1, 1) + + num_pixels = target_pos.shape[1] + img = torch.zeros((target_pos.shape[0], target_pos.shape[1], 3)) + + for i in range(0, num_pixels, max_num_pixels): + img[:, i : i + max_num_pixels], extras = model.decoder( + z.clone(), + target_pos[:, i : i + max_num_pixels], + target_kps[:, i : i + max_num_pixels], + expression_vector=expression_vector_driv, + ) + + return img.view(img.shape[0], height, width, 3), z + + +def normalize_kp( + kp_source, + kp_driving, + kp_driving_initial, + adapt_movement_scale=False, + use_relative_movement=False, +): + if adapt_movement_scale: + source_area = 
ConvexHull(kp_source.data.cpu().numpy()).volume + driving_area = ConvexHull(kp_driving_initial[0].data.cpu().numpy()).volume + adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) + else: + adapt_movement_scale = 1 + + kp_new = kp_driving + + if use_relative_movement: + kp_value_diff = kp_driving - kp_driving_initial + kp_value_diff *= adapt_movement_scale + kp_new = kp_value_diff + kp_source + + return kp_new + + +def extract_keypoints_and_expression(img, model, kp_detector, cfg, src=False): + assert kp_detector is not None + + bs, c, h, w = img.shape + nkp = kp_detector.num_kp + with torch.no_grad(): + kps, latent_dict = kp_detector(img) + heatmaps = latent_dict["heatmap"].view( + bs, nkp, latent_dict["heatmap"].shape[-2], latent_dict["heatmap"].shape[-1] + ) + feature_maps = latent_dict["feature_map"].view( + bs, + latent_dict["feature_map"].shape[-3], + latent_dict["feature_map"].shape[-2], + latent_dict["feature_map"].shape[-1], + ) + + if kps.shape[1] == 1: + kps = kps.squeeze(1) + + expression_vector = model.expression_encoder(feature_maps, heatmaps) + + if src: + expression_vector = expression_vector[None] + + return kps, expression_vector diff --git a/module_fsrt/checkpoint.py b/module_fsrt/checkpoint.py new file mode 100644 index 0000000..bfac449 --- /dev/null +++ b/module_fsrt/checkpoint.py @@ -0,0 +1,65 @@ +import os + +import torch + + +class Checkpoint: + """ + Handles saving and loading checkpoints. + + Args: + checkpoint_dir (str): path where checkpoints are saved + device: PyTorch device onto which loaded weights should be mapped + kwargs: PyTorch modules whose state should be checkpointed + """ + + def __init__(self, checkpoint_dir="./chkpts", device=None, **kwargs): + self.module_dict = kwargs + self.device = device + # self.checkpoint_dir = checkpoint_dir + # if not os.path.exists(checkpoint_dir): + # os.makedirs(checkpoint_dir) + + # def save(self, filename, **kwargs): + # """Saves the current module states + # Args: + # filename (str): name of checkpoint file + # kwargs: Additional state to save + # """ + # if not os.path.isabs(filename): + # filename = os.path.join(self.checkpoint_dir, filename) + + # outdict = kwargs + # for k, v in self.module_dict.items(): + # if k in outdict: + # print( + # f"Warning: Checkpoint key '{k}' overloaded. Defaulting to saving state_dict {v}." + # ) + # if v is not None: + # outdict[k] = v.state_dict() + # torch.save(outdict, filename) + + def load(self, filename): + """Loads a checkpoint from file. + Args: + filename (str): Name of checkpoint file. + Returns: + Dictionary containing checkpoint data which does not correspond to registered modules. 
+ """ + + # if not os.path.isabs(filename): + # filename = os.path.join(self.checkpoint_dir, filename) + + print(f"Loading checkpoint from file {filename}...") + state_dict = torch.load(filename, map_location=self.device) + + for k, v in self.module_dict.items(): + if k in state_dict: + v.load_state_dict(state_dict[k]) + else: + print(f'Warning: Could not find "{k}" in checkpoint!') + + remaining_state = { + k: v for k, v in state_dict.items() if k not in self.module_dict + } + return remaining_state diff --git a/module_fsrt/configs/vox256.yaml b/module_fsrt/configs/vox256.yaml new file mode 100644 index 0000000..67c9c2e --- /dev/null +++ b/module_fsrt/configs/vox256.yaml @@ -0,0 +1,17 @@ +data: + num_src: 1 #Model is trained with num_src source images + +model: + encoder_kwargs: + pix_octaves: 16 + pix_start_octave: -1 + kp_octaves: 4 + kp_start_octave: -1 + encode_with_expression: True #If True, the expression vector is used to encode the source image. + decoder_kwargs: + pix_start_octave: -1 + pix_octaves: 16 + kp_octaves: 4 + kp_start_octave: -1 + small_decoder: False + expression_size: 256 diff --git a/module_fsrt/configs/vox256_2Source.yaml b/module_fsrt/configs/vox256_2Source.yaml new file mode 100644 index 0000000..3e9cc9d --- /dev/null +++ b/module_fsrt/configs/vox256_2Source.yaml @@ -0,0 +1,17 @@ +data: + num_src: 2 #Model is trained with num_src source images + +model: + encoder_kwargs: + pix_octaves: 16 + pix_start_octave: -1 + kp_octaves: 4 + kp_start_octave: -1 + encode_with_expression: True #If True, the expression vector is used to encode the source image. + decoder_kwargs: + pix_start_octave: -1 + pix_octaves: 16 + kp_octaves: 4 + kp_start_octave: -1 + small_decoder: False + expression_size: 256 diff --git a/module_fsrt/decoder.py b/module_fsrt/decoder.py new file mode 100644 index 0000000..b4a2131 --- /dev/null +++ b/module_fsrt/decoder.py @@ -0,0 +1,133 @@ +import numpy as np +import torch +import torch.nn as nn + +from .layers import FSRTPosEncoder, Transformer + + +class FSRTPixelPredictor(nn.Module): + def __init__( + self, + num_att_blocks=2, + pix_octaves=16, + pix_start_octave=-1, + out_dims=3, + z_dim=768, + input_mlp=True, + output_mlp=False, + num_kp=10, + expression_size=0, + kp_octaves=4, + kp_start_octave=-1, + ): + super().__init__() + + self.positional_encoder = FSRTPosEncoder( + kp_octaves=kp_octaves, + kp_start_octave=kp_start_octave, + pix_octaves=pix_octaves, + pix_start_octave=pix_start_octave, + ) + self.expression_size = expression_size + self.num_kp = num_kp + self.feat_dim = pix_octaves * 4 + num_kp * kp_octaves * 4 + self.expression_size + + if input_mlp: # Input MLP added with OSRT improvements + self.input_mlp = nn.Sequential( + nn.Linear(self.feat_dim, 720), nn.ReLU(), nn.Linear(720, self.feat_dim) + ) + else: + self.input_mlp = None + + self.transformer = Transformer( + self.feat_dim, + depth=num_att_blocks, + heads=12, + dim_head=z_dim // 12, + mlp_dim=z_dim * 2, + selfatt=False, + kv_dim=z_dim, + ) + + if output_mlp: + self.output_mlp = nn.Sequential( + nn.Linear(self.feat_dim, 128), nn.ReLU(), nn.Linear(128, out_dims) + ) + else: + self.output_mlp = None + + def forward(self, z, pixels, keypoints, expression_vector=None): + """ + Args: + z: set-latent scene repres. [batch_size, num_patches, patch_dim] + pixels: query pixels [batch_size, num_pixels, 2] + keypoints: facial query keypoints [batch_size, num_pixels, num_kp, 2] + expression_vector: latent repres. 
of the query expression [batch_size, expression_size] + """ + bs = pixels.shape[0] + nr = pixels.shape[1] + nkp = keypoints.shape[-2] + queries = self.positional_encoder(pixels, keypoints.view(bs, nr, nkp * 2)) + + if expression_vector is not None: + queries = torch.cat( + [queries, expression_vector[:, None].repeat(1, queries.shape[1], 1)], + dim=-1, + ) + + if self.input_mlp is not None: + queries = self.input_mlp(queries) + + output = self.transformer(queries, z) + + if self.output_mlp is not None: + output = self.output_mlp(output) + + return output + + +class ImprovedFSRTDecoder(nn.Module): + """Scene Representation Transformer Decoder with the improvements from Appendix A.4 in the OSRT paper.""" + + def __init__( + self, + num_att_blocks=2, + pix_octaves=16, + pix_start_octave=-1, + num_kp=10, + kp_octaves=4, + kp_start_octave=-1, + expression_size=0, + ): + super().__init__() + self.allocation_transformer = FSRTPixelPredictor( + num_att_blocks=num_att_blocks, + pix_start_octave=pix_start_octave, + pix_octaves=pix_octaves, + z_dim=768, + input_mlp=True, + output_mlp=False, + expression_size=expression_size, + kp_octaves=kp_octaves, + kp_start_octave=kp_start_octave, + ) + self.expression_size = expression_size + self.feat_dim = pix_octaves * 4 + num_kp * kp_octaves * 4 + self.expression_size + self.render_mlp = nn.Sequential( + nn.Linear(self.feat_dim, 1536), + nn.ReLU(), + nn.Linear(1536, 1536), + nn.ReLU(), + nn.Linear(1536, 1536), + nn.ReLU(), + nn.Linear(1536, 1536), + nn.ReLU(), + nn.Linear(1536, 3), + ) + + def forward(self, z, x, pixels, expression_vector=None): + x = self.allocation_transformer( + z, x, pixels, expression_vector=expression_vector + ) + pixels = self.render_mlp(x) + return pixels, {} diff --git a/module_fsrt/encoder.py b/module_fsrt/encoder.py new file mode 100644 index 0000000..b64a953 --- /dev/null +++ b/module_fsrt/encoder.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn + +from .layers import FSRTPosEncoder, Transformer + + +class SRTConvBlock(nn.Module): + def __init__(self, idim, hdim=None, odim=None): + super().__init__() + if hdim is None: + hdim = idim + + if odim is None: + odim = 2 * hdim + + conv_kwargs = {"bias": False, "kernel_size": 3, "padding": 1} + self.layers = nn.Sequential( + nn.Conv2d(idim, hdim, stride=1, **conv_kwargs), + nn.ReLU(), + nn.Conv2d(hdim, odim, stride=2, **conv_kwargs), + nn.ReLU(), + ) + + def forward(self, x): + return self.layers(x) + + +class ImprovedFSRTEncoder(nn.Module): + """ + Scene Representation Transformer Encoder with the improvements from Appendix A.4 in the OSRT paper. 
+ """ + + def __init__( + self, + num_conv_blocks=3, + num_att_blocks=5, + pix_octaves=16, + pix_start_octave=-1, + num_kp=10, + expression_size=256, + encode_with_expression=True, + kp_octaves=4, + kp_start_octave=-1, + ): + super().__init__() + self.positional_encoder = FSRTPosEncoder( + kp_octaves=kp_octaves, + kp_start_octave=kp_start_octave, + pix_octaves=pix_octaves, + pix_start_octave=pix_start_octave, + ) + + self.encode_with_expression = encode_with_expression + if self.encode_with_expression: + self.expression_size = expression_size + else: + self.expression_size = 0 + conv_blocks = [ + SRTConvBlock( + idim=3 + + pix_octaves * 4 + + num_kp * kp_octaves * 4 + + self.expression_size, + hdim=96, + ) + ] + cur_hdim = 192 + for i in range(1, num_conv_blocks): + conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None)) + cur_hdim *= 2 + + self.conv_blocks = nn.Sequential(*conv_blocks) + + self.per_patch_linear = nn.Conv2d(cur_hdim, 768, kernel_size=1) + + self.transformer = Transformer( + 768, depth=num_att_blocks, heads=12, dim_head=64, mlp_dim=1536, selfatt=True + ) + self.num_kp = num_kp + + def forward(self, images, keypoints, pixels, expression_vector=None): + """ + Args: + images: [batch_size, num_images, 3, height, width] + keypoints: [batch_size, num_images, num_kp, 2] + pixels: [batch_size, num_images, height, width, 2] + expression_vector: [batch_size, num_images, expression_size] + Returns: + scene representation: [batch_size, num_patches, channels_per_patch] + """ + + batch_size, num_images = images.shape[:2] + + x = images.flatten(0, 1) + keypoints = keypoints.flatten(-2, -1).flatten(0, 1) + pixels = pixels.flatten(0, 1) + + pos_enc = self.positional_encoder(pixels, keypoints) + if expression_vector is not None and self.encode_with_expression: + expression_vector = expression_vector.flatten(0, 1)[ + :, :, None, None + ].repeat(1, 1, images.shape[-2], images.shape[-1]) + x = torch.cat([x, pos_enc, expression_vector], 1) + else: + x = torch.cat([x, pos_enc], 1) + x = self.conv_blocks(x) + x = self.per_patch_linear(x) + x = x.flatten(2, 3).permute(0, 2, 1) + + patches_per_image, channels_per_patch = x.shape[1:] + x = x.reshape(batch_size, num_images * patches_per_image, channels_per_patch) + + x = self.transformer(x) + + return x diff --git a/module_fsrt/expression_encoder.py b/module_fsrt/expression_encoder.py new file mode 100644 index 0000000..36dba86 --- /dev/null +++ b/module_fsrt/expression_encoder.py @@ -0,0 +1,57 @@ +from torch import nn + + +class ExpressionEncoder(nn.Module): + """ + Extracts the latent expression vector. 
+ """ + + def __init__( + self, + in_channels=32, + num_kp=10, + expression_size_per_kp=32, + expression_size=256, + pad=0, + ): + super(ExpressionEncoder, self).__init__() + + self.expression_size = expression_size # Output dimension + self.expression_size_per_kp = expression_size_per_kp # Number of output features of the convolutional layer for each keypoint + self.num_kp = num_kp + self.expression = nn.Conv2d( + in_channels=in_channels, + out_channels=num_kp * self.expression_size_per_kp, + kernel_size=(7, 7), + padding=pad, + ) + self.expression_mlp = nn.Sequential( + nn.Linear(self.expression_size_per_kp * self.num_kp, 640), + nn.ReLU(), + nn.Linear(640, 1280), + nn.ReLU(), + nn.Linear(1280, 640), + nn.ReLU(), + nn.Linear(640, self.expression_size), + ) + + def forward(self, feature_map, heatmap): + latent_expression_feat = self.expression(feature_map) + final_shape = latent_expression_feat.shape + latent_expression_feat = latent_expression_feat.reshape( + final_shape[0], + self.num_kp, + self.expression_size_per_kp, + final_shape[2], + final_shape[3], + ) + + heatmap = heatmap.unsqueeze(2) + latent_expression = heatmap * latent_expression_feat + latent_expression = latent_expression.view( + final_shape[0], self.num_kp, self.expression_size_per_kp, -1 + ) + latent_expression = latent_expression.sum(dim=-1).view(final_shape[0], -1) + latent_expression = self.expression_mlp(latent_expression) + + return latent_expression diff --git a/module_fsrt/keypoint_detector.py b/module_fsrt/keypoint_detector.py new file mode 100644 index 0000000..12add39 --- /dev/null +++ b/module_fsrt/keypoint_detector.py @@ -0,0 +1,308 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import BatchNorm2d + + +def make_coordinate_grid(spatial_size, type): + """ + Create a meshgrid [-1,1] x [-1,1] of given spatial_size. + """ + h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + + x = 2 * (x / (w - 1)) - 1 + y = 2 * (y / (h - 1)) - 1 + + yy = y.view(-1, 1).repeat(1, w) + xx = x.view(1, -1).repeat(h, 1) + + meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) + + return meshed + + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp + + coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2) + mean = mean.view(*shape) + + mean_sub = coordinate_grid - mean + + out = torch.exp(-0.5 * (mean_sub**2).sum(-1) / kp_variance) + + return out + + +class UpBlock2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock2d, self).__init__() + + self.conv = nn.Conv2d( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=groups, + ) + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. 
+ """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=groups, + ) + self.norm = BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d( + in_features + if i == 0 + else min(max_features, block_expansion * (2**i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, + padding=1, + ) + ) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min( + max_features, block_expansion * (2 ** (i + 1)) + ) + out_filters = min(max_features, block_expansion * (2**i)) + up_blocks.append( + UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1) + ) + + self.up_blocks = nn.ModuleList(up_blocks) + self.out_filters = block_expansion + in_features + + def forward(self, x): + out = x.pop() + cnt = 0 + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + cnt += 1 + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, for better preservation of the input signal. + """ + + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + sigma = (1 / scale - 1) / 2 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-((mgrid - mean) ** 2) / (2 * std**2)) + + # Make sure sum of values in gaussian kernel equals 1. 
+ kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + self.scale = scale + inv_scale = 1 / scale + self.int_inv_scale = int(inv_scale) + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = out[:, :, :: self.int_inv_scale, :: self.int_inv_scale] + + return out + + +class KPDetector(nn.Module): + """ + Detecting keypoints. Return keypoint positions. + """ + + def __init__( + self, + block_expansion=32, + num_kp=10, + num_channels=3, + max_features=1024, + num_blocks=5, + temperature=0.1, + scale_factor=0.25, + pad=0, + ): + super(KPDetector, self).__init__() + + self.predictor = Hourglass( + block_expansion, + in_features=num_channels, + max_features=max_features, + num_blocks=num_blocks, + ) + + self.kp = nn.Conv2d( + in_channels=self.predictor.out_filters, + out_channels=num_kp, + kernel_size=(7, 7), + padding=pad, + ) + + self.num_kp = num_kp + + # We do not need the Jacobian (from FOMM). + # if estimate_jacobian: + # self.num_jacobian_maps = 1 if single_jacobian_map else num_kp + # self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, + # out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) + # self.jacobian.weight.data.zero_() + # self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) + # else: + # self.jacobian = None + + self.temperature = temperature + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + + def gaussian2kp(self, heatmap): + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) + grid = ( + make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) + ) + value = (heatmap * grid).sum(dim=(2, 3)) + return value + + def forward(self, x): + with torch.no_grad(): + if self.scale_factor != 1: + x = self.down(x) + + feature_map = self.predictor(x) + prediction = self.kp(feature_map) + + final_shape = prediction.shape + heatmap = prediction.view(final_shape[0], final_shape[1], -1) + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) + + out = self.gaussian2kp(heatmap) + heatmap = heatmap.unsqueeze(2) + + # We do not need the Jacobian (from FOMM). 
+ # if self.jacobian is not None: + # jacobian_map = self.jacobian(feature_map) + # jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], + # final_shape[3]) + # jacobian = heatmap * jacobian_map + # jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) + # jacobian = jacobian.sum(dim=-1) + # jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) + + return out, {"feature_map": feature_map, "heatmap": heatmap} diff --git a/module_fsrt/layers.py b/module_fsrt/layers.py new file mode 100644 index 0000000..8db000c --- /dev/null +++ b/module_fsrt/layers.py @@ -0,0 +1,184 @@ +import math + +import torch +import torch.nn as nn +from einops import rearrange + + +class PositionalEncoding(nn.Module): + def __init__(self, num_octaves, start_octave): + super().__init__() + self.num_octaves = num_octaves + self.start_octave = start_octave + + def forward(self, coords): + embed_fns = [] + batch_size, num_points, dim = coords.shape + + octaves = torch.arange(self.start_octave, self.start_octave + self.num_octaves) + octaves = octaves.float().to(coords) + multipliers = 2**octaves * math.pi + coords = coords.unsqueeze(-1) + while len(multipliers.shape) < len(coords.shape): + multipliers = multipliers.unsqueeze(0) + + scaled_coords = coords * multipliers + + sines = torch.sin(scaled_coords).reshape( + batch_size, num_points, dim * self.num_octaves + ) + cosines = torch.cos(scaled_coords).reshape( + batch_size, num_points, dim * self.num_octaves + ) + + result = torch.cat((sines, cosines), -1) + return result + + +class FSRTPosEncoder(nn.Module): + def __init__( + self, kp_octaves=4, kp_start_octave=-1, pix_start_octave=-1, pix_octaves=16 + ): + super().__init__() + self.kp_encoding = PositionalEncoding( + num_octaves=kp_octaves, start_octave=kp_start_octave + ) + self.pix_encoding = PositionalEncoding( + num_octaves=pix_octaves, start_octave=pix_start_octave + ) + + def forward(self, pixels, kps=None): + if len(pixels.shape) == 4: + batchsize, height, width, _ = pixels.shape + pixels = pixels.flatten(1, 2) + pix_enc = self.pix_encoding(pixels) + pix_enc = pix_enc.view(batchsize, height, width, pix_enc.shape[-1]) + pix_enc = pix_enc.permute((0, 3, 1, 2)) + + if kps is not None: + kp_enc = self.kp_encoding(kps.unsqueeze(1)) + kp_enc = kp_enc.view(batchsize, kp_enc.shape[-1], 1, 1).repeat( + 1, 1, height, width + ) + x = torch.cat((kp_enc, pix_enc), 1) + else: + pix_enc = self.pix_encoding(pixels) + + if kps is not None: + kp_enc = self.kp_encoding(kps) + x = torch.cat((kp_enc, pix_enc), -1) + + return x + + +# Transformer implementation based on ViT +# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__( + self, dim, heads=8, dim_head=64, dropout=0.0, selfatt=True, kv_dim=None + ): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + 
if selfatt: + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + else: + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x, z=None): + if z is None: + qkv = self.to_qkv(x).chunk(3, dim=-1) + else: + q = self.to_q(x) + k, v = self.to_kv(z).chunk(2, dim=-1) + qkv = (q, k, v) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__( + self, + dim, + depth, + heads, + dim_head, + mlp_dim, + dropout=0.0, + selfatt=True, + kv_dim=None, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PreNorm( + dim, + Attention( + dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + selfatt=selfatt, + kv_dim=kv_dim, + ), + ), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)), + ] + ) + ) + + def forward(self, x, z=None): + for attn, ff in self.layers: + x = attn(x, z=z) + x + x = ff(x) + x + return x diff --git a/module_fsrt/model.py b/module_fsrt/model.py new file mode 100644 index 0000000..37209b7 --- /dev/null +++ b/module_fsrt/model.py @@ -0,0 +1,26 @@ +from torch import nn + +from .decoder import ImprovedFSRTDecoder +from .encoder import ImprovedFSRTEncoder +from .small_decoder import ImprovedFSRTDecoder as SmallImprovedFSRTDecoder + + +class FSRT(nn.Module): + def __init__(self, cfg, expression_encoder=None): + super().__init__() + + self.encoder = ImprovedFSRTEncoder( + expression_size=cfg["expression_size"], **cfg["encoder_kwargs"] + ) + + if cfg["small_decoder"]: + self.decoder = SmallImprovedFSRTDecoder( + expression_size=cfg["expression_size"], **cfg["decoder_kwargs"] + ) + print("Loading small decoder") + else: + self.decoder = ImprovedFSRTDecoder( + expression_size=cfg["expression_size"], **cfg["decoder_kwargs"] + ) + + self.expression_encoder = expression_encoder diff --git a/module_fsrt/models/.gitkeep b/module_fsrt/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/module_fsrt/small_decoder.py b/module_fsrt/small_decoder.py new file mode 100644 index 0000000..5247623 --- /dev/null +++ b/module_fsrt/small_decoder.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn + +from .layers import FSRTPosEncoder, Transformer + + +class FSRTPixelPredictor(nn.Module): + def __init__( + self, + num_att_blocks=2, + pix_octaves=16, + pix_start_octave=-1, + out_dims=3, + z_dim=768, + input_mlp=True, + output_mlp=False, + num_kp=10, + expression_size=0, + kp_octaves=4, + kp_start_octave=-1, + ): + super().__init__() + + self.positional_encoder = FSRTPosEncoder( + kp_octaves=kp_octaves, + kp_start_octave=kp_start_octave, + pix_octaves=pix_octaves, + pix_start_octave=pix_start_octave, + ) + self.expression_size = expression_size + self.num_kp = num_kp + self.feat_dim = pix_octaves * 4 + num_kp * kp_octaves * 4 + self.expression_size + + if input_mlp: # Input MLP added with OSRT improvements + self.input_mlp = nn.Sequential( + nn.Linear(self.feat_dim, 720), nn.ReLU(), nn.Linear(720, self.feat_dim) + ) + else: + self.input_mlp = None + + self.transformer = Transformer( + self.feat_dim, + depth=num_att_blocks, + 
heads=6, + dim_head=z_dim // 12, + mlp_dim=z_dim, + selfatt=False, + kv_dim=z_dim, + ) + + if output_mlp: + self.output_mlp = nn.Sequential( + nn.Linear(self.feat_dim, 128), nn.ReLU(), nn.Linear(128, out_dims) + ) + else: + self.output_mlp = None + + def forward(self, z, pixels, keypoints, expression_vector=None): + """ + Args: + z: set-latent scene repres. [batch_size, num_patches, patch_dim] + pixels: query pixels [batch_size, num_pixels, 2] + keypoints: facial query keypoints [batch_size, num_pixels, num_kp, 2] + expression_vector: latent repres. of the query expression [batch_size, expression_size] + """ + bs = pixels.shape[0] + nr = pixels.shape[1] + nkp = keypoints.shape[-2] + queries = self.positional_encoder(pixels, keypoints.view(bs, nr, nkp * 2)) + + if expression_vector is not None: + queries = torch.cat( + [queries, expression_vector[:, None].repeat(1, queries.shape[1], 1)], + dim=-1, + ) + + if self.input_mlp is not None: + queries = self.input_mlp(queries) + + output = self.transformer(queries, z) + + if self.output_mlp is not None: + output = self.output_mlp(output) + + return output + + +class ImprovedFSRTDecoder(nn.Module): + """Scene Representation Transformer Decoder with the improvements from Appendix A.4 in the OSRT paper.""" + + def __init__( + self, + num_att_blocks=2, + pix_octaves=16, + pix_start_octave=-1, + num_kp=10, + kp_octaves=4, + kp_start_octave=-1, + expression_size=0, + ): + super().__init__() + self.allocation_transformer = FSRTPixelPredictor( + num_att_blocks=num_att_blocks, + pix_start_octave=pix_start_octave, + pix_octaves=pix_octaves, + z_dim=768, + input_mlp=True, + output_mlp=False, + expression_size=expression_size, + kp_octaves=kp_octaves, + kp_start_octave=kp_start_octave, + ) + self.expression_size = expression_size + self.feat_dim = pix_octaves * 4 + num_kp * kp_octaves * 4 + self.expression_size + self.render_mlp = nn.Sequential( + nn.Linear(self.feat_dim, 1536), + nn.ReLU(), + nn.Linear(1536, 768), + nn.ReLU(), + nn.Linear(768, 3), + ) + + def forward(self, z, x, pixels, expression_vector=None): + x = self.allocation_transformer( + z, x, pixels, expression_vector=expression_vector + ) + pixels = self.render_mlp(x) + return pixels, {} diff --git a/pyproject.toml b/pyproject.toml index 48056e7..1dce21d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-firstordermm" description = "ComfyUI-native nodes to run First Order Motion Model for Image Animation and its non-diffusion-based successors. 
[a/https://github.com/AliaksandrSiarohin/first-order-model](https://github.com/AliaksandrSiarohin/first-order-model)" -version = "1.0.5" +version = "1.0.6" license = { file = "LICENSE" } dependencies = ["numpy", "torch", "scipy", "pyyaml", "matplotlib", "einops", "timm"] diff --git a/run.py b/run.py index 9fc0491..29f8342 100644 --- a/run.py +++ b/run.py @@ -12,6 +12,11 @@ from .constants import ( ARTICULATE_CFG_PATH, ARTICULATE_MODEL_PATH, + FSRT_CFG_PATHS, + FSRT_DEFAULT_MODEL, + FSRT_KP_PATH, + FSRT_MODEL_NAMES, + FSRT_MODEL_PATHS, MRFA_CFG_PATHS, MRFA_DEFAULT_MODEL, MRFA_MODEL_NAMES, @@ -33,6 +38,7 @@ from .face_parsing.face_parsing_loader import load_face_parser_model from .inference_articulate import articulate_inference from .inference_fomm import inference, inference_best_frame, load_checkpoint +from .inference_fsrt import fsrt_inference from .inference_mrfa import mrfa_inference from .inference_partswap import load_partswap_checkpoint, partswap_inference from .inference_spline import spline_inference @@ -583,6 +589,100 @@ def todo( frame_rate, ) +class FSRT_Runner: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "source_image": ("IMAGE",), + "driving_video_input": ("IMAGE",), + "model_name": (FSRT_MODEL_NAMES, {"default": FSRT_DEFAULT_MODEL}), + "frame_rate": ("FLOAT", {"default": 30.0}), + "relative": ( + "BOOLEAN", + {"default": False}, + ), + "adapt_scale": ( + "BOOLEAN", + {"default": False}, + ), + "find_best_frame": ( + "BOOLEAN", + {"default": False}, + ), + "max_num_pixels": ( + "INT", + {"default": 65536, "min": 1, "step": 1}, + ), + }, + "optional": {"audio": ("AUDIO",)}, + } + + RETURN_TYPES = ( + "IMAGE", + "AUDIO", + "FLOAT", + ) + RETURN_NAMES = ( + "images", + "audio", + "frame_rate", + ) + FUNCTION = "todo" + CATEGORY = "FirstOrderMM" + + def todo( + self, + source_image, + driving_video_input, + model_name: str, + frame_rate: float, + relative: bool, + adapt_scale: bool, + find_best_frame: bool, + max_num_pixels: int, + audio=None, + ): + print(f"{type(source_image)=}") # [B, H, W, C] + print(f"{type(driving_video_input)=}") + print(f"{source_image.shape=}") + print(f"{driving_video_input.shape=}") + print(f"{type(audio)=}") + print(base_dir) + + config_path = f"{base_dir}/{FSRT_CFG_PATHS[model_name]}" + checkpoint_path = f"{base_dir}/{FSRT_MODEL_PATHS[model_name]}" + + source_image = reshape_image(source_image, (256, 256)) + driving_video = reshape_image(driving_video_input, (256, 256)).unsqueeze(0) + driving_video = driving_video.permute(0, 2, 1, 3, 4) + + print("After reshaping") + print(f"{source_image.shape=}") + print(f"{driving_video.shape=}") + params = { + "source_image": source_image, + "driving_video": driving_video, + "config_path": config_path, + "checkpoint_path": checkpoint_path, + "keypoint_path": f"{base_dir}/{FSRT_KP_PATH}", + "relative": relative, + "adapt_scale": adapt_scale, + "find_best_frame": find_best_frame, + "max_num_pixels": max_num_pixels, + } + + predictions = fsrt_inference(**params) + + # output_images = out_video(predictions) + print(f"{predictions[0].shape=}") + output_images = torch.cat(predictions, dim=0) + + return ( + output_images, + audio, + frame_rate, + ) def serialize_integers(int_list): return "_".join(map(str, int_list)) diff --git a/workflows/FSRT.json b/workflows/FSRT.json new file mode 100644 index 0000000..9f1fbae --- /dev/null +++ b/workflows/FSRT.json @@ -0,0 +1,433 @@ +{ + "last_node_id": 8, + "last_link_id": 26, + "nodes": [ + { + "id": 2, + "type": "LoadImage", + "pos": [ + 83, + 105 + ], + 
"size": { + "0": 315, + "1": 314 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 20 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "Rick1.png", + "image" + ] + }, + { + "id": 5, + "type": "VHS_VideoInfoLoaded", + "pos": [ + 753, + 383 + ], + "size": { + "0": 304.79998779296875, + "1": 106 + }, + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [ + { + "name": "video_info", + "type": "VHS_VIDEOINFO", + "link": 7 + } + ], + "outputs": [ + { + "name": "fps🟦", + "type": "FLOAT", + "links": [ + 26 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "frame_count🟦", + "type": "INT", + "links": null, + "shape": 3 + }, + { + "name": "duration🟦", + "type": "FLOAT", + "links": null, + "shape": 3 + }, + { + "name": "width🟦", + "type": "INT", + "links": null, + "shape": 3 + }, + { + "name": "height🟦", + "type": "INT", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "VHS_VideoInfoLoaded" + }, + "widgets_values": {} + }, + { + "id": 6, + "type": "VHS_VideoCombine", + "pos": [ + 1236, + 163 + ], + "size": [ + 315, + 615 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 21 + }, + { + "name": "audio", + "type": "AUDIO", + "link": 23 + }, + { + "name": "meta_batch", + "type": "VHS_BatchManager", + "link": null + }, + { + "name": "vae", + "type": "VAE", + "link": null + }, + { + "name": "frame_rate", + "type": "FLOAT", + "link": 25, + "widget": { + "name": "frame_rate" + } + } + ], + "outputs": [ + { + "name": "Filenames", + "type": "VHS_FILENAMES", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "VHS_VideoCombine" + }, + "widgets_values": { + "frame_rate": 8, + "loop_count": 0, + "filename_prefix": "FSRT", + "format": "video/h264-mp4", + "pix_fmt": "yuv420p", + "crf": 0, + "save_metadata": true, + "pingpong": false, + "save_output": true, + "videopreview": { + "hidden": false, + "paused": false, + "params": { + "filename": "FSRT_00005-audio.mp4", + "subfolder": "", + "type": "output", + "format": "video/h264-mp4", + "frame_rate": 29.970029970029966 + } + } + } + }, + { + "id": 3, + "type": "VHS_LoadVideo", + "pos": [ + 430, + 196 + ], + "size": [ + 235.1999969482422, + 491.1999969482422 + ], + "flags": { + "collapsed": false + }, + "order": 1, + "mode": 0, + "inputs": [ + { + "name": "meta_batch", + "type": "VHS_BatchManager", + "link": null + }, + { + "name": "vae", + "type": "VAE", + "link": null + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 22 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "frame_count", + "type": "INT", + "links": null, + "shape": 3 + }, + { + "name": "audio", + "type": "AUDIO", + "links": [ + 24 + ], + "shape": 3, + "slot_index": 2 + }, + { + "name": "video_info", + "type": "VHS_VIDEOINFO", + "links": [ + 7 + ], + "shape": 3, + "slot_index": 3 + } + ], + "properties": { + "Node name for S&R": "VHS_LoadVideo" + }, + "widgets_values": { + "video": "damedane.mp4", + "force_rate": 0, + "force_size": "Disabled", + "custom_width": 512, + "custom_height": 512, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1, + "choose video to upload": "image", + "videopreview": { + "hidden": false, + "paused": false, + "params": { + "frame_load_cap": 0, + "skip_first_frames": 0, + 
"force_rate": 0, + "filename": "damedane.mp4", + "type": "input", + "format": "video/mp4", + "select_every_nth": 1 + } + } + } + }, + { + "id": 8, + "type": "FSRT_Runner", + "pos": [ + 762, + 65 + ], + "size": { + "0": 380.4000244140625, + "1": 218 + }, + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [ + { + "name": "source_image", + "type": "IMAGE", + "link": 20 + }, + { + "name": "driving_video_input", + "type": "IMAGE", + "link": 22 + }, + { + "name": "audio", + "type": "AUDIO", + "link": 24 + }, + { + "name": "frame_rate", + "type": "FLOAT", + "link": 26, + "widget": { + "name": "frame_rate" + } + } + ], + "outputs": [ + { + "name": "images", + "type": "IMAGE", + "links": [ + 21 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "audio", + "type": "AUDIO", + "links": [ + 23 + ], + "shape": 3, + "slot_index": 1 + }, + { + "name": "frame_rate", + "type": "FLOAT", + "links": [ + 25 + ], + "shape": 3, + "slot_index": 2 + } + ], + "properties": { + "Node name for S&R": "FSRT_Runner" + }, + "widgets_values": [ + "vox256", + 30, + false, + false, + false, + 65536 + ] + } + ], + "links": [ + [ + 7, + 3, + 3, + 5, + 0, + "VHS_VIDEOINFO" + ], + [ + 20, + 2, + 0, + 8, + 0, + "IMAGE" + ], + [ + 21, + 8, + 0, + 6, + 0, + "IMAGE" + ], + [ + 22, + 3, + 0, + 8, + 1, + "IMAGE" + ], + [ + 23, + 8, + 1, + 6, + 1, + "AUDIO" + ], + [ + 24, + 3, + 2, + 8, + 2, + "AUDIO" + ], + [ + 25, + 8, + 2, + 6, + 4, + "FLOAT" + ], + [ + 26, + 5, + 0, + 8, + 3, + "FLOAT" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1, + "offset": [ + -23, + 68 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/workflows/MRFA.json b/workflows/MRFA.json index 7b12dbf..4e7d399 100644 --- a/workflows/MRFA.json +++ b/workflows/MRFA.json @@ -1,6 +1,6 @@ { "last_node_id": 7, - "last_link_id": 18, + "last_link_id": 19, "nodes": [ { "id": 2, @@ -171,7 +171,7 @@ "hidden": false, "paused": false, "params": { - "filename": "MRFA_00003.mp4", + "filename": "MRFA_00015-audio.mp4", "subfolder": "", "type": "output", "format": "video/h264-mp4", @@ -180,6 +180,86 @@ } } }, + { + "id": 7, + "type": "MRFA_Runner", + "pos": [ + 779, + 107 + ], + "size": { + "0": 380.4000244140625, + "1": 218 + }, + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [ + { + "name": "source_image", + "type": "IMAGE", + "link": 13 + }, + { + "name": "driving_video_input", + "type": "IMAGE", + "link": 14 + }, + { + "name": "audio", + "type": "AUDIO", + "link": 19 + }, + { + "name": "frame_rate", + "type": "FLOAT", + "link": 18, + "widget": { + "name": "frame_rate" + } + } + ], + "outputs": [ + { + "name": "images", + "type": "IMAGE", + "links": [ + 15 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "audio", + "type": "AUDIO", + "links": [ + 16 + ], + "shape": 3, + "slot_index": 1 + }, + { + "name": "frame_rate", + "type": "FLOAT", + "links": [ + 17 + ], + "shape": 3, + "slot_index": 2 + } + ], + "properties": { + "Node name for S&R": "MRFA_Runner" + }, + "widgets_values": [ + "vox", + 30, + true, + true, + true, + false + ] + }, { "id": 3, "type": "VHS_LoadVideo", @@ -227,7 +307,9 @@ { "name": "audio", "type": "AUDIO", - "links": [], + "links": [ + 19 + ], "shape": 3, "slot_index": 2 }, @@ -268,86 +350,6 @@ } } } - }, - { - "id": 7, - "type": "MRFA_Runner", - "pos": [ - 779, - 107 - ], - "size": { - "0": 380.4000244140625, - "1": 218 - }, - "flags": {}, - "order": 3, - "mode": 0, - "inputs": [ - { - "name": "source_image", - "type": "IMAGE", - "link": 13 - }, - { - "name": "driving_video_input", - "type": 
"IMAGE", - "link": 14 - }, - { - "name": "audio", - "type": "AUDIO", - "link": null - }, - { - "name": "frame_rate", - "type": "FLOAT", - "link": 18, - "widget": { - "name": "frame_rate" - } - } - ], - "outputs": [ - { - "name": "images", - "type": "IMAGE", - "links": [ - 15 - ], - "shape": 3, - "slot_index": 0 - }, - { - "name": "audio", - "type": "AUDIO", - "links": [ - 16 - ], - "shape": 3, - "slot_index": 1 - }, - { - "name": "frame_rate", - "type": "FLOAT", - "links": [ - 17 - ], - "shape": 3, - "slot_index": 2 - } - ], - "properties": { - "Node name for S&R": "MRFA_Runner" - }, - "widgets_values": [ - "celebvhq", - 30, - true, - true, - true, - true - ] } ], "links": [ @@ -406,6 +408,14 @@ 7, 3, "FLOAT" + ], + [ + 19, + 3, + 2, + 7, + 2, + "AUDIO" ] ], "groups": [], diff --git a/workflows/workflow_fsrt.png b/workflows/workflow_fsrt.png new file mode 100644 index 0000000..195ffa8 Binary files /dev/null and b/workflows/workflow_fsrt.png differ diff --git a/workflows/workflow_mrfa.png b/workflows/workflow_mrfa.png index 6d075e6..813e684 100644 Binary files a/workflows/workflow_mrfa.png and b/workflows/workflow_mrfa.png differ