Commit

Merge branch 'main' of https://github.com/dvschultz/stylegan2-ada into main
Hans committed Nov 15, 2020
2 parents b7db81b + 0787add commit 1f0f929
Showing 8 changed files with 287 additions and 37 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -7,8 +7,10 @@
* **Vertical Mirroring**: use `--mirrory=True` to flip the training set top to bottom
* **Interpolation methods**: Multiple interpolation methods included in the generate.py script
* **Neighbor vectors**: Fine-tune seed selections by looking at vectors near them. Included in the generate.py script
* **Use np vectors in interpolations (in addition to seed values)**: Use saved .npy or .npz files in interpolation methods (see the sketch after this list). Thanks @ekkolabs!
* **Flesh Digressions**: @aydao’s circular constant layer script, edited to work with ADA; see aydao_flesh_digressions.py
* **Raw dataset creation**: Taken from the @skyflynil repo; reduces the size of datasets dramatically. Use `create_from_images_raw` and `create_from_image_folders_raw` in dataset creation, and use `--use-raw=True` in training (False by default!)
* **Align faces script**: From @pbaylies, this script will align images better for projection.
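For a sense of how the interpolation options above fit together, here is a minimal sketch of linear interpolation between two latent vectors, mirroring `line_interpolate` in generate.py (shapes follow StyleGAN2's 512-dim z space; the .npy file name is hypothetical):

```python
import numpy as np

# Two z vectors derived from seeds, the way generate.py builds them
z0 = np.random.RandomState(42).randn(1, 512)
z1 = np.random.RandomState(1234).randn(1, 512)

# A saved vector can stand in for either endpoint:
# z1 = np.load('my_vector.npy')  # hypothetical file

# Linear interpolation over 10 steps, as in line_interpolate()
steps = 10
frames = [z1 * (i / steps) + z0 * (1 - i / steps) for i in range(steps)]
```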

## StyleGAN2 with adaptive discriminator augmentation (ADA)<br>&mdash; Official TensorFlow implementation

15 changes: 11 additions & 4 deletions dataset_tool.py
@@ -739,20 +739,27 @@ def create_from_images_raw(tfrecord_dir, image_dir, shuffle, resolution_log2=7,
    if len(image_filenames) == 0:
        error("No input images found")

    print('loading: ' + image_filenames[0])
    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    print(resolution)
    print(2 ** int(np.floor(np.log2(resolution))))
    channels = img.shape[2] if img.ndim == 3 else 1
    # if img.shape[1] != resolution:
    #     error('Input images must have the same width and height')
    # if resolution != 2 ** int(np.floor(np.log2(resolution))):
    #     error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')

    with TFRecordExporter(tfrecord_dir, len(image_filenames), resolution_log2=resolution_log2) as tfr:
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        tfr.create_tfr_writer(img.shape)
        for idx in range(order.size):
            print('loading: ' + image_filenames[order[idx]])
            # img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            # if (img.shape[1] != 1024) or (img.shape[0] != 1024):
            #     error('Input images must have the same width and height')
            with tf.gfile.FastGFile(image_filenames[order[idx]], 'rb') as fid:
                try:
                    tfr.add_image_raw(fid.read())
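The key point of `add_image_raw` above is that the encoded file bytes go straight into the tfrecord instead of a decoded pixel array. A minimal sketch of that idea (this is not the repo's `TFRecordExporter`; the feature key and writer setup here are illustrative assumptions):

```python
import tensorflow as tf

def write_raw_tfrecord(image_paths, tfrecord_path):
    # Store compressed JPEG/PNG bytes directly, so the tfrecord stays
    # roughly the size of the source files rather than of raw pixels.
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for path in image_paths:
            with tf.io.gfile.GFile(path, 'rb') as f:
                encoded = f.read()  # never decoded here
            ex = tf.train.Example(features=tf.train.Features(feature={
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded])),
            }))
            writer.write(ex.SerializeToString())
```

Decoding then happens at training time, which is why `--use-raw=True` must be passed when training on such a dataset.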
Empty file added ffhq_dataset/__init__.py
92 changes: 92 additions & 0 deletions ffhq_dataset/face_alignment.py
@@ -0,0 +1,92 @@
import numpy as np
import scipy.ndimage
import os
import PIL.Image


def image_align(src_file, dst_file, face_landmarks, output_size=1024, transform_size=4096, enable_padding=True, x_scale=1, y_scale=1, em_scale=0.1, alpha=False):
    # Align function from FFHQ dataset pre-processing step
    # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py

    lm = np.array(face_landmarks)
    lm_chin          = lm[0  : 17]  # left-right
    lm_eyebrow_left  = lm[17 : 22]  # left-right
    lm_eyebrow_right = lm[22 : 27]  # left-right
    lm_nose          = lm[27 : 31]  # top-down
    lm_nostrils      = lm[31 : 36]  # top-down
    lm_eye_left      = lm[36 : 42]  # left-clockwise
    lm_eye_right     = lm[42 : 48]  # left-clockwise
    lm_mouth_outer   = lm[48 : 60]  # left-clockwise
    lm_mouth_inner   = lm[60 : 68]  # left-clockwise

    # Calculate auxiliary vectors.
    eye_left     = np.mean(lm_eye_left, axis=0)
    eye_right    = np.mean(lm_eye_right, axis=0)
    eye_avg      = (eye_left + eye_right) * 0.5
    eye_to_eye   = eye_right - eye_left
    mouth_left   = lm_mouth_outer[0]
    mouth_right  = lm_mouth_outer[6]
    mouth_avg    = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg

    # Choose oriented crop rectangle.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    x *= x_scale
    y = np.flipud(x) * [-y_scale, y_scale]
    c = eye_avg + eye_to_mouth * em_scale
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2

    # Load in-the-wild image.
    if not os.path.isfile(src_file):
        print('\nCannot find source image. Please run "--wilds" before "--align".')
        return
    img = PIL.Image.open(src_file).convert('RGBA').convert('RGB')

    # Shrink.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
        img = img.resize(rsize, PIL.Image.ANTIALIAS)
        quad /= shrink
        qsize /= shrink

    # Crop.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(crop)
        quad -= crop[0:2]

    # Pad.
    pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
        blur = qsize * 0.02
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
        img = np.uint8(np.clip(np.rint(img), 0, 255))
        if alpha:
            mask = 1 - np.clip(3.0 * mask, 0.0, 1.0)
            mask = np.uint8(np.clip(np.rint(mask * 255), 0, 255))
            img = np.concatenate((img, mask), axis=2)
            img = PIL.Image.fromarray(img, 'RGBA')
        else:
            img = PIL.Image.fromarray(img, 'RGB')
        quad += pad[:2]

    # Transform.
    img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
    if output_size < transform_size:
        img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)

    # Save aligned image.
    img.save(dst_file, 'PNG')
21 changes: 21 additions & 0 deletions ffhq_dataset/landmarks_detector.py
@@ -0,0 +1,21 @@
import dlib


class LandmarksDetector:
    def __init__(self, predictor_model_path):
        """
        :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
        """
        self.detector = dlib.get_frontal_face_detector()  # cnn_face_detection_model_v1 also can be used
        self.shape_predictor = dlib.shape_predictor(predictor_model_path)

    def get_landmarks(self, image):
        img = dlib.load_rgb_image(image)
        dets = self.detector(img, 1)

        for detection in dets:
            try:
                face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()]
                yield face_landmarks
            except:
                print("Exception in get_landmarks()!")
90 changes: 57 additions & 33 deletions generate.py
@@ -6,6 +6,7 @@
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.


"""Generate images using pretrained network pickle."""

import argparse
@@ -184,19 +185,19 @@ def get_noiseloop(endpoints, nf, d, start_seed):
    for f in range(nf):
        z = np.random.randn(1, 512)
        for i in range(512):
            z[0,i] = features[i].get_val(inc*f)
        zs.append(z)

    return zs

def line_interpolate(zs, steps):
    out = []
    for i in range(len(zs)-1):
        for index in range(steps):
            fraction = index/float(steps)
            out.append(zs[i+1]*fraction + zs[i]*(1-fraction))
    return out

def generate_zs_from_seeds(seeds, Gs):
    zs = []
    for seed_idx, seed in enumerate(seeds):
@@ -209,18 +210,18 @@ def convertZtoW(latent, truncation_psi=0.7, truncation_cutoff=9):
    dlatent = Gs.components.mapping.run(latent, None) # [seed, layer, component]
    dlatent_avg = Gs.get_var('dlatent_avg') # [component]
    dlatent = dlatent_avg + (dlatent - dlatent_avg) * truncation_psi

    return dlatent

def generate_latent_images(zs, truncation_psi, outdir, save_npy, prefix, vidname, framerate):
    Gs_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        'randomize_noise': False
    }

    if not isinstance(truncation_psi, list):
        truncation_psi = [truncation_psi] * len(zs)

    for z_idx, z in enumerate(zs):
        if isinstance(z, list):
            z = np.array(z).reshape(1,512)
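As context for `convertZtoW` above, a small numeric sketch of the truncation step (shapes follow StyleGAN2's 1×18×512 dlatents; the values here are made up):

```python
import numpy as np

dlatent_avg = np.zeros(512)                            # stand-in for Gs.get_var('dlatent_avg')
dlatent = np.random.RandomState(0).randn(1, 18, 512)   # stand-in for the mapping network output
truncation_psi = 0.7

# psi < 1 pulls each w vector toward the average, trading variety for fidelity
truncated = dlatent_avg + (dlatent - dlatent_avg) * truncation_psi
```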
@@ -245,7 +246,7 @@ def generate_images_in_w_space(ws, truncation_psi,outdir,save_npy,prefix,vidname
        'randomize_noise': False,
        'truncation_psi': truncation_psi
    }

    for w_idx, w in enumerate(ws):
        print('Generating image for step %d/%d ...' % (w_idx, len(ws)))
        noise_rnd = np.random.RandomState(1) # fix noise
@@ -261,12 +262,13 @@ def generate_latent_walk(network_pkl, truncation_psi, outdir, walk_type, frames,
def generate_latent_walk(network_pkl, truncation_psi, outdir, walk_type, frames, seeds, npys, save_vector, diameter=2.0, start_seed=0, framerate=24):
    global _G, _D, Gs, noise_vars
    tflib.init_tf()

    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as fp:
        _G, _D, Gs = pickle.load(fp)

    os.makedirs(outdir, exist_ok=True)

    # Render images for dlatents initialized from random seeds.
    Gs_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
@@ -276,31 +278,38 @@ def generate_latent_walk(network_pkl, truncation_psi, outdir, walk_type, frames,

    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
    zs = []

    # elif(len(npys) > 0):
    #     zs = npys

    if (len(zs) > 2):
        print('not enough values to generate walk')
        # return false;
    ws = []

    # npys specified, let's work with these instead of seeds
    # npys must be saved as W's (arrays of 18x512)
    if npys and (len(npys) > 0):
        ws = npys

    wt = walk_type.split('-')

    if wt[0] == 'line':
        if seeds and (len(seeds) > 0):
            zs = generate_zs_from_seeds(seeds, Gs)

        if ws == []:
            number_of_steps = int(frames/(len(zs)-1))+1
        else:
            number_of_steps = int(frames/(len(ws)-1))+1

        if (len(wt) > 1 and wt[1] == 'w'):
            if ws == []:
                for i in range(len(zs)):
                    ws.append(convertZtoW(zs[i]))

            points = line_interpolate(ws, number_of_steps)
            zpoints = line_interpolate(zs, number_of_steps)

        else:
            points = line_interpolate(zs, number_of_steps)

    # from Gene Kogan
    elif wt[0] == 'bspline':
        # bspline in w doesn't work yet
@@ -330,7 +339,12 @@ def generate_latent_walk(network_pkl, truncation_psi, outdir, walk_type, frames,
        # ws = []
        # for i in enumerate(len(points)):
        #     ws.append(convertZtoW(points[i]))

        # added for npys
        if seeds:
            seed_out = 'w-' + wt[0] + ('-'.join([str(seed) for seed in seeds]))
        else:
            seed_out = 'w-' + wt[0] + '-dlatents'

        generate_images_in_w_space(points, truncation_psi, outdir, save_vector, 'frame', seed_out, framerate)
    elif (len(wt) > 1 and wt[1] == 'w'):
        print('%s is not currently supported in w space, please change your interpolation type' % (wt[0]))
@@ -348,10 +362,10 @@ def generate_neighbors(network_pkl, seeds, npys, diameter, truncation_psi, num_s
    tflib.init_tf()
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as fp:
        _G, _D, Gs = pickle.load(fp)

    os.makedirs(outdir, exist_ok=True)

    # Render images for dlatents initialized from random seeds.
    Gs_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
@@ -364,7 +378,7 @@ def generate_neighbors(network_pkl, seeds, npys, diameter, truncation_psi, num_s
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx+1, len(seeds)))
        rnd = np.random.RandomState(seed)

        og_z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
        tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
        images = Gs.run(og_z, None, **Gs_kwargs) # [minibatch, height, width, channel]
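The loop above draws `og_z` for each requested seed; the neighbor-vectors feature builds on it by re-sampling z within a small ball around that point. A hedged sketch of the idea only (the exact sampling used by this repo, and how `diameter` scales it, may differ):

```python
import numpy as np

rnd = np.random.RandomState(42)
og_z = rnd.randn(1, 512)          # the seed's original z, as in the loop above
diameter = 2.0                    # mirrors the --diameter option

# one nearby vector; generating several and rendering each explores the neighborhood
neighbor_z = og_z + rnd.uniform(-diameter / 2, diameter / 2, size=og_z.shape)
```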
@@ -559,12 +573,22 @@ def _parse_num_range_ext(s):

def _parse_npy_files(files):
    '''Accept a comma separated list of npy files and return a list of z vectors.'''
    print(files)

    zs = []
    file_list = files.split(",")

    for f in file_list:
        # load numpy array
        arr = np.load(f)
        # check if it's actually npz:
        if 'dlatents' in arr:
            arr = arr['dlatents']
        zs.append(arr)

    return zs

#----------------------------------------------------------------------------
37 changes: 37 additions & 0 deletions utils/align_faces.py
@@ -0,0 +1,37 @@
import os
import sys
import bz2
from keras.utils import get_file
from ffhq_dataset.face_alignment import image_align
from ffhq_dataset.landmarks_detector import LandmarksDetector

LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'


def unpack_bz2(src_path):
    data = bz2.BZ2File(src_path).read()
    dst_path = src_path[:-4]
    with open(dst_path, 'wb') as fp:
        fp.write(data)
    return dst_path


if __name__ == "__main__":
    """
    Extracts and aligns all faces from images using dlib and a function from the original FFHQ dataset preparation step
    python utils/align_faces.py /raw_images /aligned_images
    """

    landmarks_model_path = unpack_bz2(get_file('shape_predictor_68_face_landmarks.dat.bz2',
                                               LANDMARKS_MODEL_URL, cache_subdir='temp'))
    RAW_IMAGES_DIR = sys.argv[1]
    ALIGNED_IMAGES_DIR = sys.argv[2]

    landmarks_detector = LandmarksDetector(landmarks_model_path)
    for img_name in [x for x in os.listdir(RAW_IMAGES_DIR) if x[0] not in '._']:
        raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
        for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1):
            face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
            aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
            os.makedirs(ALIGNED_IMAGES_DIR, exist_ok=True)
            image_align(raw_img_path, aligned_face_path, face_landmarks)
