Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update inference.py #103

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 111 additions & 89 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch
import time
import os
import logging
from tqdm import tqdm

from utils.inference.image_processing import crop_face, get_final_image
from utils.inference.video_processing import read_video, get_target, get_final_video, add_audio_from_another_video, face_enhancement
Expand All @@ -17,137 +19,157 @@
from models.config_sr import TestOptions


# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def init_models(args):
# model for face cropping
app = Face_detect_crop(name='antelope', root='./insightface_func/models')
app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640, 640))

# main model for generation
G = AEI_Net(args.backbone, num_blocks=args.num_blocks, c_id=512)
G.eval()
G.load_state_dict(torch.load(args.G_path, map_location=torch.device('cpu')))
G = G.cuda()
G = G.half()
G = G.cuda().half()

# arcface model to get face embedding
netArc = iresnet100(fp16=False)
netArc.load_state_dict(torch.load('arcface_model/backbone.pth'))
netArc=netArc.cuda()
netArc = netArc.cuda()
netArc.eval()

# model to get face landmarks
handler = Handler('./coordinate_reg/model/2d106det', 0, ctx_id=0, det_size=640)

# model to make superres of face, set use_sr=True if you want to use super resolution or use_sr=False if you don't
# model to make superres of face, set use_sr=True if you want to use super resolution
model = None
if args.use_sr:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.backends.cudnn.benchmark = True
opt = TestOptions()
#opt.which_epoch ='10_7'
model = Pix2PixModel(opt)
model.netG.train()
else:
model = None

model.netG.eval() # Ensure the model is in evaluation mode

return app, G, netArc, handler, model


def main(args):
app, G, netArc, handler, model = init_models(args)

# get crops from source images
print('List of source paths: ',args.source_paths)
source = []
try:
for source_path in args.source_paths:
img = cv2.imread(source_path)
img = crop_face(img, app, args.crop_size)[0]
source.append(img[:, :, ::-1])
except TypeError:
print("Bad source images!")
exit()

# get full frames from video
if not args.image_to_image:
full_frames, fps = read_video(args.target_video)
else:
target_full = cv2.imread(args.target_image)
full_frames = [target_full]

# get target faces that are used for swap


def load_faces(paths, app, crop_size, to_rgb=True):
    """Load one cropped face per image path, skipping paths that fail.

    Args:
        paths: Iterable of image file paths.
        app: Face detector object, passed through to ``crop_face``.
        crop_size: Crop size, passed through to ``crop_face``.
        to_rgb: When true (the default, matching the previous behaviour),
            reverse the channel axis of every crop (``[:, :, ::-1]``).
            ``cv2.imread`` decodes to BGR, so this presumably yields RGB
            crops -- confirm against what ``crop_face`` returns. Pass
            ``False`` to keep the crop exactly as ``crop_face`` returned it.

    Returns:
        A list with the first crop of every path that could be processed.
        Failures are logged and omitted, so the list may be shorter than
        ``paths`` (and may be empty).
    """
    faces = []
    for path in paths:
        # Deliberately best-effort: one bad image must not abort the batch,
        # so any per-path failure is logged and the path is skipped.
        try:
            img = cv2.imread(path)
            # cv2.imread reports a missing/undecodable file by returning
            # None instead of raising, so check for it explicitly.
            if img is None:
                raise ValueError(f"Could not read image {path}")
            cropped_faces = crop_face(img, app, crop_size)
            # Avoid bare truthiness here: it is ambiguous (and raises) if
            # crop_face ever hands back an array rather than a list.
            if cropped_faces is None or len(cropped_faces) == 0:
                raise ValueError(f"No face detected in {path}")
            face = cropped_faces[0]  # First face detected
            faces.append(face[:, :, ::-1] if to_rgb else face)
        except Exception as e:
            logging.error(f"Error loading face from {path}: {e}")
    return faces


def process_source_and_target(args, app):
# Load and process source images
logging.info("Processing source images...")
source = load_faces(args.source_paths, app, args.crop_size)

# Load and process target images (either from file or video frames)
set_target = True
print('List of target paths: ', args.target_faces_paths)
target = []
if not args.target_faces_paths:
logging.info("No target faces provided, selecting faces from the video...")
full_frames, _ = read_video(args.target_video)
target = get_target(full_frames, app, args.crop_size)
set_target = False
else:
target = []
try:
for target_faces_path in args.target_faces_paths:
img = cv2.imread(target_faces_path)
img = crop_face(img, app, args.crop_size)[0]
target.append(img)
except TypeError:
print("Bad target images!")
exit()

start = time.time()
final_frames_list, crop_frames_list, full_frames, tfm_array_list = model_inference(full_frames,
source,
target,
netArc,
G,
app,
set_target,
similarity_th=args.similarity_th,
crop_size=args.crop_size,
BS=args.batch_size)
if args.use_sr:
final_frames_list = face_enhancement(final_frames_list, model)
logging.info("Processing target face images...")
target = load_faces(args.target_faces_paths, app, args.crop_size)

return source, target, set_target


def perform_inference(source, target, full_frames, app, G, netArc, args, set_target=True):
    """Run ``model_inference`` for the face swap with autograd disabled.

    Args:
        source: Source face crops.
        target: Target face crops.
        full_frames: Frames to process.
        app: Face detector object.
        G: Generator network.
        netArc: Face-embedding network.
        args: Parsed CLI arguments; ``similarity_th``, ``crop_size`` and
            ``batch_size`` are read from it.
        set_target: Forwarded to ``model_inference`` as ``set_target``.
            Defaults to ``True``, the value that was previously hard-coded,
            so existing callers are unaffected; callers that computed their
            own flag can now pass it through.

    Returns:
        The 4-tuple ``(final_frames_list, crop_frames_list, full_frames,
        tfm_array_list)`` produced by ``model_inference``.
    """
    # Inference only: no_grad avoids building autograd graphs.
    with torch.no_grad():
        final_frames_list, crop_frames_list, full_frames, tfm_array_list = model_inference(
            full_frames, source, target, netArc, G, app,
            set_target=set_target, similarity_th=args.similarity_th,
            crop_size=args.crop_size, BS=args.batch_size
        )
    return final_frames_list, crop_frames_list, full_frames, tfm_array_list


def enhance_faces(final_frames_list, model):
    """Log the step, then pass the swapped frames through ``face_enhancement``.

    Args:
        final_frames_list: Frames produced by the swap stage.
        model: Super-resolution model handed to ``face_enhancement``.

    Returns:
        Whatever ``face_enhancement`` returns for the given frames.
    """
    logging.info("Enhancing faces using super resolution...")
    enhanced_frames = face_enhancement(final_frames_list, model)
    return enhanced_frames


def save_output(final_frames_list, crop_frames_list, full_frames, tfm_array_list, args):
if not args.image_to_image:
get_final_video(final_frames_list,
crop_frames_list,
full_frames,
tfm_array_list,
args.out_video_name,
fps,
handler)

logging.info(f"Saving output video to {args.out_video_name}...")
get_final_video(final_frames_list, crop_frames_list, full_frames, tfm_array_list, args.out_video_name, fps, handler)
add_audio_from_another_video(args.target_video, args.out_video_name, "audio")
print(f"Video saved with path {args.out_video_name}")
else:
logging.info(f"Saving output image to {args.out_image_name}...")
result = get_final_image(final_frames_list, crop_frames_list, full_frames[0], tfm_array_list, handler)
cv2.imwrite(args.out_image_name, result)
print(f'Swapped Image saved with path {args.out_image_name}')

print('Total time: ', time.time()-start)


logging.info("Output saved successfully.")


def main(args):
    """Run the full pipeline: load models, swap faces, optionally enhance, save.

    Args:
        args: Parsed CLI arguments. This function reads ``image_to_image``,
            ``target_image``, ``target_video`` and ``use_sr`` directly and
            passes ``args`` on to the helper functions.
    """
    # Initialize models and parameters
    app, G, netArc, handler, model = init_models(args)

    # Process source and target images
    # NOTE(review): set_target is returned here but this call chain does not
    # forward it to perform_inference -- confirm that is intended.
    source, target, set_target = process_source_and_target(args, app)

    start_time = time.time()
    logging.info("Starting face swapping...")

    # Get full frames from the video, or wrap the single target image.
    # The two modes must be branched explicitly: unpacking a one-element list
    # into (full_frames, fps) raises ValueError in image-to-image mode.
    if args.image_to_image:
        full_frames = [cv2.imread(args.target_image)]
        fps = None  # a still image has no frame rate
    else:
        full_frames, fps = read_video(args.target_video)

    # Perform inference
    final_frames_list, crop_frames_list, full_frames, tfm_array_list = perform_inference(
        source, target, full_frames, app, G, netArc, args
    )

    # Enhance faces if required
    if args.use_sr:
        final_frames_list = enhance_faces(final_frames_list, model)

    # Save the output
    # NOTE(review): save_output's body appears to use `fps` and `handler`,
    # which are locals of this function and are not passed in -- verify.
    save_output(final_frames_list, crop_frames_list, full_frames, tfm_array_list, args)

    logging.info(f"Total time: {time.time() - start_time:.2f} seconds")


if __name__ == "__main__":
parser = argparse.ArgumentParser()

# Generator params
parser.add_argument('--G_path', default='weights/G_unet_2blocks.pth', type=str, help='Path to weights for G')
parser.add_argument('--backbone', default='unet', const='unet', nargs='?', choices=['unet', 'linknet', 'resnet'], help='Backbone for attribute encoder')
parser.add_argument('--num_blocks', default=2, type=int, help='Numbers of AddBlocks at AddResblock')
parser.add_argument('--num_blocks', default=2, type=int, help='Number of AddBlocks at AddResblock')

parser.add_argument('--batch_size', default=40, type=int)
parser.add_argument('--crop_size', default=224, type=int, help="Don't change this")
parser.add_argument('--use_sr', default=False, type=bool, help='True for super resolution on swap images')
parser.add_argument('--similarity_th', default=0.15, type=float, help='Threshold for selecting a face similar to the target')

parser.add_argument('--source_paths', default=['examples/images/mark.jpg', 'examples/images/elon_musk.jpg'], nargs='+')
parser.add_argument('--target_faces_paths', default=[], nargs='+', help="It's necessary to set the face/faces in the video to which the source face/faces is swapped. You can skip this parametr, and then any face is selected in the target video for swap.")
# parameters for image to video
parser.add_argument('--target_video', default='examples/videos/nggyup.mp4', type=str, help="It's necessary for image to video swap")
parser.add_argument('--out_video_name', default='examples/results/result.mp4', type=str, help="It's necessary for image to video swap")
# parameters for image to image
parser.add_argument('--image_to_image', default=False, type=bool, help='True for image to image swap, False for swap on video')
parser.add_argument('--target_image', default='examples/images/beckham.jpg', type=str, help="It's necessary for image to image swap")
parser.add_argument('--out_image_name', default='examples/results/result.png', type=str,help="It's necessary for image to image swap")
parser.add_argument('--target_faces_paths', default=[], nargs='+', help="List of target face images")

# Parameters for image to video
parser.add_argument('--target_video', default='examples/videos/nggyup.mp4', type=str, help="Target video for swapping faces")
parser.add_argument('--out_video_name', default='examples/results/result.mp4', type=str, help="Output video name")

# Parameters for image to image
parser.add_argument('--image_to_image', default=False, type=bool, help='True for image to image swap, False for image to video swap')
parser.add_argument('--target_image', default='examples/images/beckham.jpg', type=str, help="Target image for swapping faces")
parser.add_argument('--out_image_name', default='examples/results/result.png', type=str, help="Output image name")

args = parser.parse_args()
main(args)
main(args)