Generating faces with specific attributes
The problem of conditional image generation consists of generating images that match a given set of attributes, such as a particular color or visual feature.
In our case, we are interested in generating human faces with specific sets of attributes found in the CelebA dataset. In order to do so, we are going to use a particular kind of generative model called a generative adversarial network (GAN).
The mathematical details of how GANs work will be omitted in this post, which will instead focus on the programming side of our task.
GANs
Evolution of GAN-generated samples over the last couple of years. Credits to Ian Goodfellow.
Generative adversarial networks are currently one of the most widely used approaches to image generation, as they can (or, better said, try to) generate sharp, realistic-looking images.
The realistic look of the images comes from the training architecture, which employs two neural networks, a generator and a discriminator, trained alternately.
The generator produces images from a given input noise, while the discriminator classifies incoming images as real (coming from the training dataset) or fake (produced by the generator). As training goes on, the generator keeps trying to produce better-looking samples.
The problem with all this, however, is that one of the two networks may learn faster than the other, causing instability and yielding undesired images.
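To make the alternating scheme concrete, here is a minimal, runnable sketch of the idea on 1-D toy data (the tiny networks and the N(3, 1) "real" distribution are purely illustrative; the actual face models are defined later in this post):
import tensorflow as tf

# Toy generator and discriminator working on 1-D samples (illustrative only)
generator = tf.keras.Sequential([tf.keras.layers.Dense(16, activation="relu"),
                                 tf.keras.layers.Dense(1)])
discriminator = tf.keras.Sequential([tf.keras.layers.Dense(16, activation="relu"),
                                     tf.keras.layers.Dense(1)])
g_opt = tf.keras.optimizers.Adam(1e-3)
d_opt = tf.keras.optimizers.Adam(1e-3)
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

for step in range(1000):
    real = tf.random.normal((32, 1), mean=3.0)   # "real" data ~ N(3, 1)
    z = tf.random.normal((32, 1))                # input noise
    # Discriminator step: push real scores up, fake scores down
    with tf.GradientTape() as tape:
        d_loss = (bce(tf.ones((32, 1)), discriminator(real)) +
                  bce(tf.zeros((32, 1)), discriminator(generator(z))))
    d_opt.apply_gradients(zip(tape.gradient(d_loss, discriminator.trainable_variables),
                              discriminator.trainable_variables))
    # Generator step: try to fool the (fixed) discriminator
    with tf.GradientTape() as tape:
        g_loss = bce(tf.ones((32, 1)), discriminator(generator(z)))
    g_opt.apply_gradients(zip(tape.gradient(g_loss, generator.trainable_variables),
                              generator.trainable_variables))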
Preparations
Preparing environment
As our programming language we are going to use Python 3, with TensorFlow 2.2 (or higher), its high-level Keras API, and a few additional libraries.
Note: we'll use Anaconda, which will allow us to easily install all the dependencies, as the GPU build of TensorFlow has some annoying manual ones.
It is a good and useful practice to install libraries into a separate virtual environment (to avoid messing up the global ones), so let us use Anaconda to create an environment named tf2_faces:
conda create --name tf2_faces
Next, we’ll activate the environment and install the necessary dependencies:
conda activate tf2_faces
conda install pip
conda install -c anaconda tensorflow-gpu
conda install -c anaconda pandas
conda install -c conda-forge opencv
conda install -c conda-forge matplotlib
Now our environment is ready to be used; however, we are still missing our training data.
Downloading data
We are going to use the CelebA dataset, which contains around 200,000 face images, each annotated with 40 binary attributes (beard or no beard, bald or not bald, etc.).
CelebA dataset preview from the original authors' site.
For our task, make sure to download the Anno/list_attr_celeba.txt and Img/img_align_celeba.zip files from the official CelebA Google Drive folder.
Next, extract the images zip archive in a folder and put the list_attr_celeba.txt file in the same folder:
mkdir CelebA
mv img_align_celeba.zip CelebA/img_align_celeba.zip
mv list_attr_celeba.txt CelebA/list_attr_celeba.txt
cd CelebA
unzip img_align_celeba.zip
Our data is now ready to be used.
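As an optional sanity check, we can count the extracted images (the full dataset should contain 202,599 files):
import os

# The aligned CelebA archive should contain roughly 200k JPEG images
print(len(os.listdir("CelebA/img_align_celeba")))  # expected: 202599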
Coding
Here we present our code for loading the data, and for defining and training our model.
Data loading and preview
Let us first define a function that loads a single image, crops it, resizes it to 64x64, and normalizes it into the [-1, 1] range, as is usually done with GANs:
import cv2
import numpy as np
def load_image(image_path, resize_size, crop_pt_1=None, crop_pt_2=None):
"""
Loads an image from a given path, crops it around the rectangle defined by two points, and resizes it.
Inputs:
image_path: path of the image
resize_size: new size of the output image
crop_pt_1: first point of crop
crop_pt_2: second point of crop
Returns:
image: resized image and rescaled to [-1, 1] interval
"""
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
shape = image.shape
if crop_pt_1 is None:
crop_pt_1 = (0, 0)
if crop_pt_2 is None:
crop_pt_2 = (shape[0], shape[1])
image = image[crop_pt_1[0]:crop_pt_2[0], crop_pt_1[1]:crop_pt_2[1]]
resized = cv2.resize(image, resize_size)
resized = resized.astype(np.float32)
return (resized - 127.5) / 127.5
This function will be useful as it will do all the necessary preprocessing for our data.
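For instance, we can quickly test it on one of the extracted images (000001.jpg is simply the first file in the archive, and the crop points are the same ones we'll use for training):
# Load, crop and resize a single face; values end up in [-1, 1]
img = load_image("CelebA/img_align_celeba/000001.jpg", (64, 64),
                 crop_pt_1=(45, 25), crop_pt_2=(173, 153))
print(img.shape, img.min(), img.max())  # (64, 64, 3), values within [-1, 1]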
Then we'll define a subclass of tensorflow.keras.utils.Sequence that lets us dynamically load batches of images in a memory-efficient way (i.e. without loading all the image data at once):
import math
import os
import random
import numpy as np
from tensorflow.keras.utils import Sequence
class DataSequence(Sequence):
"""
Keras Sequence object to train a model on larger-than-memory data.
"""
def __init__(self, df, data_root, batch_size, resize_size=(64, 64), flip_augment=True, mode='train'):
self.df = df
self.batch_size = batch_size
self.mode = mode
self.resize_size = resize_size
self.crop_pt_1 = (45, 25)
self.crop_pt_2 = (173, 153)
self.flip_augment = flip_augment
# Extract columns from df columns
self.label_columns = self.df.columns[1:].tolist()
# Take labels and a list of image locations in memory
self.labels = self.df[self.label_columns].values
self.im_list = self.df['Image_Name'].apply(lambda x: os.path.join(data_root, x)).tolist()
# Trigger a shuffle
self.on_epoch_end()
def __len__(self):
return int(math.floor(len(self.df) / float(self.batch_size)))
def on_epoch_end(self):
# Shuffles indexes after each epoch if in training mode
self.indexes = range(len(self.im_list))
if self.mode == 'train':
self.indexes = random.sample(self.indexes, k=len(self.indexes))
def get_batch_labels(self, idx):
# Fetch a batch of labels
return self.labels[idx]
def get_batch_features(self, idx):
images = []
for im_idx in idx:
im = self.im_list[im_idx]
loaded_image = load_image(im, self.resize_size, self.crop_pt_1, self.crop_pt_2)
if self.flip_augment and random.random() < 0.5:
loaded_image = np.flip(loaded_image, 1)
images.append(loaded_image)
# Fetch a batch of inputs
return np.array(images)
def __getitem__(self, index):
idx = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# Get the actual data
batch_x = self.get_batch_features(idx)
batch_y = np.clip(self.get_batch_labels(idx).astype(np.float32), 0, 1)
return (batch_x, batch_y), batch_y
Note: for a better explanation of each part of DataSequence's code, check this post by Afshine Amidi and Shervine Amidi.
In order to test the above code, we first define a function that will allow us to visualize image data as a grid:
import matplotlib.pyplot as plt
def display_images(image_batch, cols=4, rows=8):
height, width = image_batch.shape[1:3]
reshaped = np.reshape(image_batch, (rows, cols, height, width, 3)).transpose(0, 2, 1, 3, 4).reshape(rows*height, cols * width, 3)
plt.imshow(reshaped)
plt.axis('off')
plt.show()
Finally, we can test our DataSequence by loading a batch of images and displaying it using our freshly defined display_images function:
import pandas as pd

# Note: list_attr_celeba.txt is not a plain CSV: its first line holds the image
# count, followed by a whitespace-separated header and rows of "<image> <40 attributes>"
train_data_df = pd.read_csv("CelebA/list_attr_celeba.txt", sep=r"\s+", skiprows=1)
train_data_df = train_data_df.reset_index().rename(columns={"index": "Image_Name"})
training_generator = DataSequence(train_data_df, "CelebA/img_align_celeba", batch_size=32)
(images, labels), _ = next(iter(training_generator))
display_images(images)
This should produce a figure similar to the one below, with 32 faces whose intensities were scaled into the [-1, 1] range (hence why they look dark):
We are now ready to define our model.
Model definition
The model we'll be using is a simple DCGAN with pixelwise normalization, spectral normalization, and hinge loss, all of which help stabilize training; training stability is one of the main problems of GANs (i.e. they might start generating nonsense images as training proceeds).
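For reference, the hinge losses we'll implement later can be written as follows, with D(x) the discriminator's raw realness score and G(z) a generated image (our code additionally averages the two discriminator terms):
L_D = E_x[max(0, 1 - D(x))] + E_z[max(0, 1 + D(G(z)))]
L_G = -E_z[D(G(z))]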
Custom Layers
First we need to define custom Keras layers with spectral normalization:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.python.keras.utils import conv_utils
# epsilon set according to BigGAN https://arxiv.org/pdf/1809.11096.pdf
def _l2normalizer(v, epsilon=1e-4):
    return v / (K.sum(v**2)**0.5 + epsilon)

def power_iteration(W, u, rounds=1):
    '''
    According to the paper, a single round of power iteration is enough.
    '''
    _u = u
    for i in range(rounds):
        _v = _l2normalizer(K.dot(_u, W))
        _u = _l2normalizer(K.dot(_v, K.transpose(W)))
    W_sn = K.sum(K.dot(_u, W) * _v)
    return W_sn, _u, _v
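As a quick sanity check (with arbitrary shapes), the estimate returned by power_iteration should, given a few extra rounds, approach the largest singular value computed by NumPy:
import numpy as np

W = np.random.randn(64, 128).astype(np.float32)  # fake flattened kernel
u = np.random.randn(1, 64).astype(np.float32)    # fake singular vector estimate
W_sn, _, _ = power_iteration(K.constant(W), K.constant(u), rounds=50)
print(float(W_sn), np.linalg.svd(W, compute_uv=False)[0])  # values should be close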
The code above performs the imports and defines a couple of utility functions; below we define the actual layer for spectrally normalized 2D convolution:
class SNConv2D(tf.keras.layers.Conv2D):
def __init__(self, filters, spectral_normalization=True, **kwargs):
self.spectral_normalization = spectral_normalization
super(SNConv2D, self).__init__(filters, **kwargs)
def build(self, input_shape):
# Create a trainable weight variable for this layer.
self.u = self.add_weight(name='u', shape=(1, self.filters),
initializer='uniform', trainable=False)
super(SNConv2D, self).build(input_shape)
# Be sure to call this at the end
def compute_spectral_normal(self, training=True):
# Spectrally Normalized Weight
if self.spectral_normalization:
# Get kernel tensor shape [kernel_h, kernel_w, in_channels, out_channels]
W_shape = self.kernel.shape.as_list()
# Flatten the Tensor
# [out_channels, N]
W_mat = K.reshape(self.kernel, [W_shape[-1], -1])
W_sn, u, v = power_iteration(W_mat, self.u)
if training:
# Update estimated 1st singular vector
self.u.assign(u)
return self.kernel / W_sn
else:
return self.kernel
def call(self, inputs, training=None):
outputs = K.conv2d(inputs,
self.compute_spectral_normal(training=training),
strides=self.strides, padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate)
if self.use_bias:
outputs = K.bias_add(outputs, self.bias,
data_format=self.data_format)
if self.activation is not None:
return self.activation(outputs)
return outputs
def compute_output_shape(self, input_shape):
return super(SNConv2D, self).compute_output_shape(input_shape)
and spectrally normalized 2D deconvolution:
class SNConv2DTranspose(tf.keras.layers.Conv2DTranspose):
def __init__(self, spectral_normalization=True, **kwargs):
self.spectral_normalization = spectral_normalization
super(SNConv2DTranspose, self).__init__(**kwargs)
def build(self, input_shape):
# Create a trainable weight variable for this layer.
self.u = self.add_weight(name='u', shape=(1, self.filters),
initializer='uniform', trainable=False)
super(SNConv2DTranspose, self).build(input_shape)
# Be sure to call this at the end
def compute_spectral_normal(self, training=True):
# Spectrally Normalized Weight
if self.spectral_normalization:
# Get kernel tensor shape [kernel_h, kernel_w, in_channels, out_channels]
W_shape = self.kernel.shape.as_list()
# Flatten the Tensor
# [out_channels, N]
W_mat = K.reshape(self.kernel, [W_shape[-2], -1])
W_sn, u, v = power_iteration(W_mat, self.u)
if training:
# Update estimated 1st singular vector
self.u.assign(u)
return self.kernel / W_sn
else:
return self.kernel
def call(self, inputs, training=None):
input_shape = K.shape(inputs)
batch_size = input_shape[0]
if self.data_format == 'channels_first':
h_axis, w_axis = 2, 3
else:
h_axis, w_axis = 1, 2
height, width = input_shape[h_axis], input_shape[w_axis]
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
out_pad_h = out_pad_w = None
# Infer the dynamic output shape:
out_height = conv_utils.deconv_output_length(
height, kernel_h, self.padding, stride=stride_h)
out_width = conv_utils.deconv_output_length(
width, kernel_w, self.padding, stride=stride_w)
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_height, out_width)
else:
output_shape = (batch_size, out_height, out_width, self.filters)
outputs = K.conv2d_transpose(
inputs,
self.compute_spectral_normal(training=training),
output_shape,
self.strides,
padding=self.padding,
data_format=self.data_format
)
if self.use_bias:
outputs = K.bias_add(outputs, self.bias,
data_format=self.data_format)
if self.activation is not None:
return self.activation(outputs)
return outputs
def compute_output_shape(self, input_shape):
return super(SNConv2DTranspose, self).compute_output_shape(input_shape)
Don't mind the overwhelming amount of code: these classes are just wrappers around Keras's existing 2D convolution and deconvolution implementations.
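A minimal smoke test (the shapes here are arbitrary) to verify the layers behave like their plain Keras counterparts:
x = tf.random.normal((2, 64, 64, 3))
conv = SNConv2D(16, kernel_size=3, strides=(2, 2), padding="same")
deconv = SNConv2DTranspose(filters=16, kernel_size=3, strides=(2, 2), padding="same")
print(conv(x, training=True).shape)    # expected: (2, 32, 32, 16)
print(deconv(x, training=True).shape)  # expected: (2, 128, 128, 16)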
Next, we'll define the PixelNormalization layer, first introduced in ProgressiveGAN as an additional source of stabilization:
class PixelNormalization(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(PixelNormalization, self).__init__(**kwargs)
def call(self, inputs):
# Calculate square pixel values
values = inputs**2.0
# Calculate the mean pixel values
mean_values = tf.keras.backend.mean(values, axis=-1, keepdims=True)
# Ensure the mean is not zero
mean_values += 1.0e-8
# Calculate the sqrt of the mean squared value (L2 norm)
l2 = tf.keras.backend.sqrt(mean_values)
# Normalize values by the l2 norm
normalized = inputs / l2
return normalized
def compute_output_shape(self, input_shape):
return input_shape
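A quick way to convince ourselves this works: after normalization, each pixel's channel vector should have (approximately) unit mean square:
x = tf.random.normal((1, 4, 4, 8)) * 5.0  # arbitrary scale
y = PixelNormalization()(x)
print(tf.reduce_mean(y ** 2, axis=-1))  # all values should be ~1.0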
Models
We can finally define our GAN model's components: the generator, the discriminator, and a wrapper model.
The generator and discriminator follow a standard DCGAN architecture, with pixelwise normalization layers in place of batch normalization and with spectrally normalized convolutions.
Our generator will take an additional input vector holding the image conditions:
class Generator(tf.keras.Model):
def __init__(self, latent_dim, filters=64, kernel_size=4, cond_dim=40):
super(Generator, self).__init__()
self.filters = filters
self.latent_dim = latent_dim
self.dense1 = tf.keras.layers.Dense(4 * 4 * filters * 8, input_shape=(latent_dim + cond_dim,))
self.relu1 = tf.keras.layers.ReLU()
self.reshape1 = tf.keras.layers.Reshape((4, 4, filters * 8))
self.bn1 = PixelNormalization()
# 4x4 -> 8x8
self.block1_upscale = tf.keras.layers.UpSampling2D()
self.block1_conv1 = SNConv2D(
filters=filters * 4, kernel_size=kernel_size, strides=(1, 1),
padding="same")
self.block1_bn = PixelNormalization()
self.block1_relu1 = tf.keras.layers.ReLU()
# 8x8 -> 16x16
self.block2_upscale = tf.keras.layers.UpSampling2D()
self.block2_conv1 = SNConv2D(
filters=filters * 2, kernel_size=kernel_size, strides=(1, 1),
padding="same")
self.block2_bn = PixelNormalization()
self.block2_relu1 = tf.keras.layers.ReLU()
# 16x16 -> 32x32
self.block3_upscale = tf.keras.layers.UpSampling2D()
self.block3_conv1 = SNConv2D(
filters=filters, kernel_size=kernel_size, strides=(1, 1),
padding="same")
self.block3_bn = PixelNormalization()
self.block3_relu1 = tf.keras.layers.ReLU()
# 32x32 -> 64x64
self.block4_upscale = tf.keras.layers.UpSampling2D()
self.block4_conv1 = SNConv2D(
filters=filters // 2, kernel_size=kernel_size, strides=(1, 1),
padding="same")
self.block4_bn = PixelNormalization()
self.block4_relu1 = tf.keras.layers.ReLU()
# 64 x 64 x FILTERS -> 64 x 64 x 3
self.block5_conv1 = SNConv2D(
filters=3, kernel_size=4, strides=(1, 1),
padding="same", activation="tanh")
@tf.function
def call(self, z, conditions, training=False):
x = tf.concat([z, conditions], axis=-1)
x = self.dense1(x)
x = self.relu1(x)
x = self.reshape1(x)
x = self.bn1(x, training=training)
x = self.block1_upscale(x)
x = self.block1_conv1(x)
x = self.block1_bn(x, training=training)
x = self.block1_relu1(x)
x = self.block2_upscale(x)
x = self.block2_conv1(x)
x = self.block2_bn(x, training=training)
x = self.block2_relu1(x)
x = self.block3_upscale(x)
x = self.block3_conv1(x)
x = self.block3_bn(x, training=training)
x = self.block3_relu1(x)
x = self.block4_upscale(x)
x = self.block4_conv1(x)
x = self.block4_bn(x, training=training)
x = self.block4_relu1(x)
images = self.block5_conv1(x)
return images
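Before moving on, we can sanity-check the generator on random noise and random binary conditions (the values below are arbitrary):
z = tf.random.normal((4, 128))                                # latent vectors
cond = tf.cast(tf.random.uniform((4, 40)) < 0.5, tf.float32)  # random attributes
gen = Generator(latent_dim=128)
print(gen(z, cond, training=False).shape)  # expected: (4, 64, 64, 3)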
The discriminator will output both a real/fake score and classification scores for the CelebA attributes:
class Discriminator(tf.keras.Model):
def __init__(self, filters=64, kernel_size=3, n_classes=40):
super(Discriminator, self).__init__()
self.filters = filters
# 64 x 64 x FILTERS
self.block1_conv1 = SNConv2D(filters, kernel_size=kernel_size, strides=(2, 2), padding="same")
self.block1_lrelu1 = tf.keras.layers.LeakyReLU(alpha=0.1)
# 32 x 32 x FILTERS
self.block2_conv1 = SNConv2D(filters * 2, kernel_size=kernel_size, strides=(2, 2), padding="same")
self.block2_lrelu1 = tf.keras.layers.LeakyReLU(alpha=0.1)
# 16 x 16 x FILTERS
self.block3_conv1 = SNConv2D(filters * 4, kernel_size=kernel_size, strides=(2, 2), padding="same")
self.block3_lrelu1 = tf.keras.layers.LeakyReLU(alpha=0.1)
# 8 x 8 x FILTERS
self.block4_conv1 = SNConv2D(filters * 8, kernel_size=kernel_size, strides=(2, 2), padding="same")
self.block4_lrelu1 = tf.keras.layers.LeakyReLU(alpha=0.1)
# Current size: 4 x 4 x FILTERS
self.avg_pool = tf.keras.layers.GlobalAvgPool2D()
self.scoring = tf.keras.layers.Dense(1)
self.classifier = tf.keras.layers.Dense(n_classes, activation="sigmoid")
@tf.function
def call(self, images, training=False):
x = images
x = self.block1_conv1(x)
x = self.block1_lrelu1(x)
x = self.block2_conv1(x)
x = self.block2_lrelu1(x)
x = self.block3_conv1(x)
x = self.block3_lrelu1(x)
x = self.block4_conv1(x)
x = self.block4_lrelu1(x)
x = self.avg_pool(x)
scores = self.scoring(x)
classes = self.classifier(x)
return scores, classes
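We can give the discriminator the same treatment; it should return one realness score and 40 attribute probabilities per image:
imgs = tf.random.normal((4, 64, 64, 3))  # fake image batch
disc = Discriminator()
scores, classes = disc(imgs, training=False)
print(scores.shape, classes.shape)  # expected: (4, 1) (4, 40)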
The wrapper model overrides train_step to allow training via Keras's fit method, as shown in the Keras DCGAN example:
class GAN_Wrapper(tf.keras.Model):
def __init__(self, discriminator, generator, **kwargs):
super(GAN_Wrapper, self).__init__(**kwargs)
self.discriminator = discriminator
self.generator = generator
self.latent_dim = self.generator.latent_dim
def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
super(GAN_Wrapper, self).compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.d_classifier_loss = tf.keras.losses.BinaryCrossentropy()
self.d_loss_fn = d_loss_fn
self.g_loss_fn = g_loss_fn
def call(self, x, training=False):
"""
This method is overridden only because it is required by tf.keras.Model
"""
pass
@tf.function
def train_step(self, data):
# Unpack images and their labels
(real_images, conditions), _ = data
# Sample random points in the latent space
batch_size = tf.shape(real_images)[0]
# Train the discriminator
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
with tf.GradientTape() as tape:
# Generate fake images from the latent vector
fake_images = self.generator(random_latent_vectors, conditions, training=False)
# Get the logits for the fake images
fake_logits, fake_classes = self.discriminator(fake_images, training=True)
# Get the logits for real images
real_logits, real_classes = self.discriminator(real_images, training=True)
# Calculate discriminator loss using fake and real logits
d_cost = self.d_loss_fn(logits_real=real_logits, logits_fake=fake_logits)
d_class_real = self.d_classifier_loss(y_true=conditions, y_pred=real_classes)
# Add the class loss to the original discriminator loss
d_loss = d_cost + d_class_real
# Get the gradients w.r.t the discriminator loss
d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
# Update the weights of the discriminator using the discriminator optimizer
self.d_optimizer.apply_gradients(
zip(d_gradient, self.discriminator.trainable_variables)
)
# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# Train the generator
with tf.GradientTape() as tape:
# Generate fake images using the generator
generated_images = self.generator(random_latent_vectors, conditions, training=True)
# Get the discriminator logits for fake images
gen_img_logits, gen_img_classes = self.discriminator(generated_images, training=False)
# Calculate the generator loss
g_cost = self.g_loss_fn(gen_img_logits)
g_class_gen = self.d_classifier_loss(y_true=conditions, y_pred=gen_img_classes)
g_loss = g_cost + g_class_gen
grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
return {
"d_loss": d_loss,
"g_loss": g_loss,
"D(G(z))": tf.reduce_mean(gen_img_logits),
"D(x)": tf.reduce_mean(real_logits)
}
Having defined our model and image data loader, we can finally proceed to train it.
Training
First, we declare our training parameters:
# Fit parameters
batch_size = 32
n_epochs = 30
# Generator parameters
latent_dim = 128
conditional_dim = 40
filters_gen = 64
kernel_size_gen = 3
# Discriminator parameters
filters_disc = 64
kernel_size_disc = 3
Then we instantiate the Generator, Discriminator, and GAN_Wrapper objects:
# Create models
generator_model = Generator(latent_dim, filters=filters_gen, kernel_size=kernel_size_gen)
discriminator_model = Discriminator(filters=filters_disc, kernel_size=kernel_size_disc)
# Create gan wrapper model
gan_model = GAN_Wrapper(discriminator_model, generator_model)
the hinge loss:
@tf.function
def discriminator_loss(logits_real, logits_fake):
real_loss = tf.keras.backend.mean(tf.keras.backend.relu(1 - logits_real))
fake_loss = tf.keras.backend.mean(tf.keras.backend.relu(1 + logits_fake))
return (fake_loss + real_loss) / 2
@tf.function
def generator_loss(logits_fake):
return -tf.keras.backend.mean(logits_fake)
and compile:
gan_model.compile(
d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.00005, beta_1=0.0),
g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.0),
d_loss_fn=discriminator_loss,
g_loss_fn=generator_loss,
)
Finally, we can perform the actual fitting of our GAN:
history = gan_model.fit(training_generator,
use_multiprocessing=True,
workers=8,
epochs=n_epochs,
    callbacks=[
        # save_weights_only is needed here: subclassed models
        # cannot be serialized whole to the HDF5 format
        tf.keras.callbacks.ModelCheckpoint(filepath='saved_models/model_{epoch}.h5',
                                           save_weights_only=True)
    ])
The above code will train a GAN model to generate human faces with specific visual attributes.
Notice that we used the ModelCheckpoint callback in order to save our model's weights every epoch. If we also wanted to save sample images during training, we could add the following callback, which generates images from a given set of conditions, to the callbacks list:
ImagesLoggingCallback(generator_model, latent_dim, labels, "images")
Where ImagesLoggingCallback is defined as:
class ImagesLoggingCallback(tf.keras.callbacks.Callback):
def __init__(self, generator, latent_dim, view_conditions, images_dir, n_images=32, rows=4, cols=8):
super(ImagesLoggingCallback, self).__init__()
# Note: we can access self.model.generator also
self.generator = generator
self.n_images = n_images
self.rows = rows
self.cols = cols
self.latent_dim = latent_dim
self.images_dir = images_dir
self.view_cond = view_conditions
self.random_latent_vectors = tf.random.normal(shape=(self.n_images, self.latent_dim))
def on_epoch_begin(self, epoch, logs=None):
# Generate images
generated_images = self.generator(self.random_latent_vectors, self.view_conditions, training=False)
generated_images = (generated_images + 1) / 2.0
        generated_images = generated_images.numpy()
# Save the figure
fig = plt.figure()
height, width = generated_images.shape[1:3]
reshaped = np.reshape(generated_images, (self.rows, self.cols, height, width, 3)).transpose(0, 2, 1, 3, 4).reshape(self.rows*height, self.cols * width, 3)
plt.imshow(reshaped)
plt.axis('off')
fig.savefig(self.images_dir+"/sample_{}.png".format(epoch))
plt.close(fig)
Results
The overall training should take around an hour on a GTX 1080 GPU and produce results similar to the following:
We can play with the conditional labels and, for instance, make every generated sample blonde:
Or make everyone smile:
Improvements
As seen in the images above, our model can generate faces that resemble human ones; however, most of them still look unrealistic.
One possible solution is, as always, to train the model longer, as we've trained for only 30 epochs.
Another, simpler solution is the one we describe next.
EMA
Instead of training a single generator model, whose updates might still cause some instability (producing not-so-smooth images), we can jointly maintain a second generator that is updated via an exponential moving average (EMA) rather than with gradient methods.
More precisely, let G be a generator model with weights u, and let w be the weights of its EMA copy. Our training procedure will:
- Clone the generator model, producing G'.
- After each gradient update of G, update G''s weights using EMA: w_{t+1} = (1 - β) · u_t + β · w_t
The value of β is chosen such that 1 - β is close to zero, for example β = 0.9999: each EMA step then moves w only 0.01% of the way toward the current weights u.
Let us define an EMA update function, which takes two models as input and updates the second one's weights using a small fraction of the first's:
def ema_update(model, model_ema, beta=0.9999):
"""
Performs a model update by using exponential moving average (EMA)
of first model's weights to update second model's weights
as defined in "The Unusual Effectiveness of Averaging in GAN Training" (https://arxiv.org/abs/1806.04498),
realizing the following update:
model_ema.weights = beta * model_ema.weights + (1 - beta) * model.weights
:param model: original, gradient descent trained model.
:param model_ema: clone of original model, that will get updated using EMA.
:param beta: EMA update weight, determines what portion of model_ema weights should be kept (default=0.9999).
"""
# for each model layer index
for i in range(len(model.layers)):
updating_weights = model.layers[i].get_weights() # original model's weights
ema_old_weights = model_ema.layers[i].get_weights() # ema model's weights
ema_new_weights = [] # ema model's update weights
if len(updating_weights) != len(ema_old_weights):
# weight list length mismatch between model's weights and ema model's weights
print("Different weight length")
# copy ema weights directly from the model's weights
ema_new_weights = updating_weights
else:
# for each weight tensor of original model's weights list
for j in range(len(updating_weights)):
n_weight = beta * ema_old_weights[j] + (1 - beta) * updating_weights[j]
ema_new_weights.append(n_weight)
# update weights
model_ema.layers[i].set_weights(ema_new_weights)
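As a tiny illustrative check of ema_update on a toy model (the numbers are arbitrary): after one step with beta=0.9, the EMA weights should move only 10% of the way toward the trained model's weights:
m = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
m_ema = tf.keras.models.clone_model(m)
m_ema.set_weights(m.get_weights())
# Pretend a gradient step moved every weight of m by +1.0
m.layers[0].set_weights([w + 1.0 for w in m.layers[0].get_weights()])
ema_update(m, m_ema, beta=0.9)
# m_ema's weights have moved by 0.1 * 1.0 = 0.1
print(m.layers[0].get_weights()[1], m_ema.layers[0].get_weights()[1])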
Then we define the actual wrapper for our EMA model, a subclass of the original GAN_Wrapper:
class GAN_WrapperEMA(GAN_Wrapper):
def __init__(self, discriminator, generator, generator_ema, beta=0.9999):
super(GAN_WrapperEMA, self).__init__(discriminator, generator)
self.generator_ema = generator_ema
self.generator_ema.set_weights(self.generator.get_weights())
self.beta = beta
self.current_step = tf.Variable(initial_value=0, trainable=False)
def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
super(GAN_WrapperEMA, self).compile(d_optimizer, g_optimizer, d_loss_fn, g_loss_fn)
def ema_step(self):
ema_update(self.generator, self.generator_ema, beta=self.beta)
def call(self, x, training=False):
"""
This method is overridden only because it is required by tf.keras.Model
"""
pass
@tf.function
def train_step(self, data):
stats_dict = super(GAN_WrapperEMA, self).train_step(data)
self.current_step.assign_add(1)
return stats_dict
Now we need to instantiate all of the objects again:
# EMA params
ema_beta = 0.9999
ema_start = 10
ema_every = 2
# create models
generator_model = Generator(latent_dim, filters=filters_gen, kernel_size=kernel_size_gen)
generator_model_ema = Generator(latent_dim, filters=filters_gen, kernel_size=kernel_size_gen)
discriminator_model = Discriminator(filters=filters_disc, kernel_size=kernel_size_disc)
# create gan wrapper model
gan_model = GAN_WrapperEMA(discriminator_model, generator_model, generator_model_ema, beta=ema_beta)
# images callback
image_save_callback = ImagesLoggingCallback(generator_model_ema, latent_dim, labels, "images")
# compile model
gan_model.compile(
d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.00005, beta_1=0.0),
g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.0),
d_loss_fn=discriminator_loss,
g_loss_fn=generator_loss,
)
As we can see from the code, GAN_WrapperEMA initially sets generator_model_ema's weights to generator_model's weights.
One extra step is needed: since train_step is compiled with tf.function to optimize its execution time, and ema_update relies on get_weights/set_weights (which operate on NumPy arrays), we cannot call ema_update inside train_step. Instead, we'll call it outside train_step, by implementing a custom training loop.
To implement the custom training loop, we'll first define two utility functions to help us keep track of the model's stats:
def update_dict(orig_dict, new_dict):
for key in orig_dict.keys():
orig_dict[key].append(new_dict[key])
return orig_dict
def print_stats(epoch, step, stats_dict):
print("Epoch: {}, Step: {}, {}".format(epoch, step, {key: tf.reduce_mean(value).numpy() for key, value in stats_dict.items()}))
And finally, we implement the loop itself, which executes train_step to update the original generator_model via the gradient optimizers. Additionally, every ema_every steps (after an initial ema_start warm-up), we'll perform an ema_step, which updates generator_model_ema internally using EMA.
# Print params
print_output_every = 400
# Train the model
for epoch in range(n_epochs):
    image_save_callback.on_epoch_begin(epoch)
    stats_dict = {"d_loss": [], "g_loss": [], "D(G(z))": [], "D(x)": []}
    for step, x in enumerate(training_generator):
        train_details = gan_model.train_step(x)
        update_dict(stats_dict, train_details)
        if gan_model.current_step > ema_start and gan_model.current_step % ema_every == 0:
            gan_model.ema_step()
        if step % print_output_every == 0:
            print_stats(epoch, step, stats_dict)
    # Reshuffle the data; fit would normally trigger this for us
    training_generator.on_epoch_end()
    # Note: we'll need to save weights manually,
    # as the ModelCheckpoint callback will fail
    # to save custom Tensorflow models whole.
    gan_model.save_weights("saved_models/model_{}.h5".format(epoch))
Running our code for 30 epochs, as previously, should produce better-looking results, similar to the ones below.
Let us also show attribute manipulation with the same conditions as before.
Base:
Blonde:
Blonde and smiling: