-
Notifications
You must be signed in to change notification settings - Fork 78
/
test_image_swap_multi.py
45 lines (36 loc) · 1.71 KB
/
test_image_swap_multi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# -*- coding: utf-8 -*-
# @Author: netrunner-exe
# @Date: 2022-12-21 12:52:01
# @Last Modified by: netrunner-exe
# @Last Modified time: 2022-12-21 19:14:34
import logging
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow_addons.layers import InstanceNormalization
from networks.layers import AdaIN, AdaptiveAttention
from retinaface.models import *
from utils.options import FaceDancerOptions
from utils.swap_func import run_inference
# Raise the root logger's threshold so only ERROR and above are emitted by
# loggers that propagate to it.
logging.getLogger().setLevel(logging.ERROR)


def _select_gpu(device_id):
    """Restrict TensorFlow to the single GPU at index *device_id*.

    Does nothing when no GPU is present, so the script still runs on CPU.
    TensorFlow only accepts visible-device changes before its runtime is
    initialized, so call this before any model is loaded.

    Args:
        device_id: index into the list of physical GPUs.

    Raises:
        ValueError: if GPUs exist but *device_id* is not a valid index
            (negative indices are rejected rather than counting from the end).
    """
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        return
    if not 0 <= device_id < len(gpus):
        raise ValueError(
            'device_id {} is out of range: {} GPU(s) available'.format(
                device_id, len(gpus)))
    tf.config.set_visible_devices(gpus[device_id], 'GPU')


def _load_models(opt):
    """Load and return the (RetinaFace, ArcFace, FaceDancer generator) models.

    All models are loaded with ``compile=False`` since they are only used for
    inference. The custom layer classes each saved model references are
    passed through ``custom_objects`` so Keras can deserialize them.

    Args:
        opt: parsed options; reads ``retina_path``, ``arcface_path`` and
            ``facedancer_path``.
    """
    retina_face = load_model(opt.retina_path, compile=False,
                             custom_objects={"FPN": FPN,
                                             "SSH": SSH,
                                             "BboxHead": BboxHead,
                                             "LandmarkHead": LandmarkHead,
                                             "ClassHead": ClassHead})
    arc_face = load_model(opt.arcface_path, compile=False)
    generator = load_model(opt.facedancer_path, compile=False,
                           custom_objects={"AdaIN": AdaIN,
                                           "AdaptiveAttention": AdaptiveAttention,
                                           "InstanceNormalization": InstanceNormalization})
    return retina_face, arc_face, generator


def main():
    """Swap the face from ``opt.swap_source`` into ``opt.img_path``.

    Parses command-line options, pins the requested GPU, loads the three
    models, prints the generator summary, runs inference and reports the
    output path ``opt.img_output``.
    """
    opt = FaceDancerOptions().parse()

    _select_gpu(opt.device_id)

    print('\nInitializing FaceDancer...')
    retina_face, arc_face, generator = _load_models(opt)
    generator.summary()

    print('\nProcessing: {}'.format(opt.img_path))
    run_inference(opt, opt.swap_source, opt.img_path,
                  retina_face, arc_face, generator, opt.img_output)
    print('\nDone! {}'.format(opt.img_output))


if __name__ == '__main__':
    main()