Generating faces with specific attributes


The problem of conditional image generation consists of generating images that exhibit a given set of attributes, such as a particular color or visual feature.

In our case, we are interested in generating human faces with a given set of the attributes found in the CelebA dataset. To do so, we are going to use a particular kind of generative model called generative adversarial networks (GANs).

The mathematical details of how GANs work are omitted in this post, which instead focuses on the programming side of our task.

GANs

Evolution of GAN-generated samples over the last few years. Credit: Ian Goodfellow.

Generative adversarial networks are currently one of the most widely used approaches to image generation, as they are able to generate (or, better said, “try to” generate) sharp, realistic-looking images.

The realistic look of the images comes from the training setup, which employs two neural networks, a generator and a discriminator, that are trained alternately.

The generator's role is to produce images from input noise, while the discriminator classifies incoming images as real (coming from the training dataset) or fake (produced by the generator). As training goes on, the generator learns to produce better-looking samples in order to fool the discriminator.

The problem, however, is that one of the two networks may learn faster than the other, which makes training unstable and can yield undesired images.

Preparations

Preparing environment

As our programming language we are going to use Python 3, with TensorFlow 2.2 (or higher), its high-level Keras API, and a few additional libraries.

Note: We’ll use Anaconda, which lets us easily install all the dependencies, since TensorFlow GPU otherwise requires some tedious manual setup (such as installing matching CUDA and cuDNN libraries).

It is good practice to install the libraries into a separate virtual environment (to avoid messing up the global ones), so let us use Anaconda to create an environment named tf2_faces:

conda create --name tf2_faces

Next, we’ll activate the environment and install the necessary dependencies:

conda activate tf2_faces
conda install pip
conda install -c anaconda tensorflow-gpu
conda install -c anaconda pandas
conda install -c conda-forge opencv
conda install -c conda-forge matplotlib 

Our environment is now ready to be used, but we still need our training data.

Downloading data

We are going to use the CelebA dataset, which contains around 200,000 face images, each annotated with 40 binary attributes (e.g. whether the person has a beard, is bald, is smiling, etc.).

CelebA dataset preview, from the original authors' site.

For our task, make sure to download the Anno/list_attr_celeba.txt and Img/img_align_celeba.zip files from the official CelebA drive folder.

Next, create a folder, move both files into it, and extract the images archive there:

mkdir CelebA
mv img_align_celeba.zip CelebA/img_align_celeba.zip
mv list_attr_celeba.txt CelebA/list_attr_celeba.txt
cd CelebA
unzip img_align_celeba.zip

Our data is now also ready to be used.
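Before writing any loading code, it helps to know what list_attr_celeba.txt looks like: its first line holds the number of images, its second line the 40 attribute names, and every following line an image file name followed by 40 values in {-1, 1}, separated by runs of spaces. The sketch below parses a tiny, made-up excerpt with only three attributes (the values are illustrative, not copied from the real file, and read_celeba_attributes is a helper name of our own) into a DataFrame whose first column is Image_Name, which is the layout DataSequence expects later on.

```python
import io

import pandas as pd

# Made-up excerpt in the same layout as list_attr_celeba.txt (3 attributes instead of 40)
SAMPLE = """3
Bald Blond_Hair Smiling
000001.jpg -1  1  1
000002.jpg -1 -1  1
000003.jpg  1 -1 -1
"""


def read_celeba_attributes(source):
    """Parse a CelebA-style attribute file into a DataFrame with an Image_Name column."""
    # Skip the image-count line. The header row has one field fewer than the data
    # rows, so pandas uses the first field of each row (the file name) as the index.
    df = pd.read_csv(source, sep=r"\s+", skiprows=1)
    # Turn the file-name index back into a regular, first column
    return df.reset_index().rename(columns={"index": "Image_Name"})


attrs = read_celeba_attributes(io.StringIO(SAMPLE))
print(attrs.columns.tolist())
```

The same call works on the real file path; note that the labels stay in {-1, 1} here, and DataSequence clips them to {0, 1} when building batches.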

Coding

In this section we present our code for loading the data, and for defining and training our model.

Data loading and preview

Let us first define a function that loads a single image, crops it, resizes it to 64x64 and normalizes it to the [-1, 1] range, as is usually done with GANs:

import cv2
import numpy as np

def load_image(image_path, resize_size, crop_pt_1=None, crop_pt_2=None):
    """
    Loads an image from a given path, crops it around the rectangle defined by two points, and resizes it.

    Inputs:
    image_path: path of the image
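    resize_size: (width, height) of the output image, as expected by cv2.resize
    crop_pt_1: first (row, col) corner of the crop
    crop_pt_2: second (row, col) corner of the crop

    Returns:
    image: resized image and rescaled to [-1, 1] interval
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None instead of raising when a file cannot be read
    if image is None:
        raise FileNotFoundError("Could not read image: {}".format(image_path))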
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    shape = image.shape
    if crop_pt_1 is None:
        crop_pt_1 = (0, 0)
    if crop_pt_2 is None:
        crop_pt_2 = (shape[0], shape[1])
    image = image[crop_pt_1[0]:crop_pt_2[0], crop_pt_1[1]:crop_pt_2[1]]
    resized = cv2.resize(image, resize_size)
    resized = resized.astype(np.float32)
    return (resized - 127.5) / 127.5

This function performs all the preprocessing our data needs.
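The last line of load_image maps 8-bit intensities from [0, 255] to [-1, 1], the same range as the generator's tanh output. Below is a minimal sketch of that mapping and of its inverse, which we need whenever images have to be turned back into displayable values; the helper names to_model_range and to_display_range are ours, not part of the post's code.

```python
import numpy as np


def to_model_range(image_uint8):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1], as load_image does."""
    return (image_uint8.astype(np.float32) - 127.5) / 127.5


def to_display_range(image_float):
    """Inverse mapping from [-1, 1] back to [0, 1], suitable for plt.imshow."""
    return np.clip((image_float + 1.0) / 2.0, 0.0, 1.0)


pixels = np.array([0, 127, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
print(scaled)
print(to_display_range(scaled))
```

The (x + 1) / 2 inverse is the same rescaling the image-logging callback applies to generated samples before saving them.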

Then we’ll define a subclass of tensorflow.keras.utils.Sequence, which lets us load batches of images on the fly in a memory-efficient way (i.e. without loading all the image data at once):

import math
import os
import random
import numpy as np
from tensorflow.keras.utils import Sequence

class DataSequence(Sequence):
    """
    Keras Sequence object to train a model on larger-than-memory data.
    """
    def __init__(self, df, data_root, batch_size, resize_size=(64, 64), flip_augment=True, mode='train'):
        self.df = df
        self.batch_size = batch_size
        self.mode = mode
        self.resize_size = resize_size
        self.crop_pt_1 = (45, 25)
        self.crop_pt_2 = (173, 153)
        self.flip_augment = flip_augment
        # Attribute columns: every column except the first one (Image_Name)
        self.label_columns = self.df.columns[1:].tolist()

        # Take labels and a list of image locations in memory
        self.labels = self.df[self.label_columns].values
        self.im_list = self.df['Image_Name'].apply(lambda x: os.path.join(data_root, x)).tolist()
        # Trigger a shuffle
        self.on_epoch_end()

    def __len__(self):
        return int(math.floor(len(self.df) / float(self.batch_size)))

    def on_epoch_end(self):
        # Shuffles indexes after each epoch if in training mode
        self.indexes = range(len(self.im_list))
        if self.mode == 'train':
            self.indexes = random.sample(self.indexes, k=len(self.indexes))

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return self.labels[idx]

    def get_batch_features(self, idx):
        images = []
        for im_idx in idx:
            im = self.im_list[im_idx]
            loaded_image = load_image(im, self.resize_size, self.crop_pt_1, self.crop_pt_2)
            if self.flip_augment and random.random() < 0.5:
                loaded_image = np.flip(loaded_image, 1)
            images.append(loaded_image)
        # Fetch a batch of inputs
        return np.array(images)

    def __getitem__(self, index):
        idx = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # Get the actual data
        batch_x = self.get_batch_features(idx)
        batch_y = np.clip(self.get_batch_labels(idx).astype(np.float32), 0, 1)
        return (batch_x, batch_y), batch_y

Note: for a more detailed explanation of each part of DataSequence’s code, check this post by Afshine Amidi and Shervine Amidi.
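To make the batching logic of DataSequence concrete, here is a framework-free sketch of what on_epoch_end and __getitem__ do together: shuffle the indices once per epoch (in training mode), then cut them into consecutive slices of batch_size. The function name epoch_batches is hypothetical. Because __len__ uses floor division, the last incomplete batch is simply dropped.

```python
import math
import random


def epoch_batches(n_samples, batch_size, shuffle=True, seed=None):
    """Return the list of index batches DataSequence would serve in one epoch."""
    indexes = list(range(n_samples))
    if shuffle:
        # Same effect as random.sample(indexes, k=len(indexes)) in on_epoch_end
        random.Random(seed).shuffle(indexes)
    n_batches = math.floor(n_samples / batch_size)  # as in __len__
    # Consecutive slices, as in __getitem__
    return [indexes[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]


batches = epoch_batches(10, batch_size=4, shuffle=False)
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7]]; samples 8 and 9 are dropped
```

With the full CelebA attribute file (202,599 images) and batch_size=32, this gives 6,331 batches per epoch.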

In order to test the above code, we first define a function that will allow us to visualize image data as a grid:

import matplotlib.pyplot as plt

def display_images(image_batch, cols=4, rows=8):
    height, width = image_batch.shape[1:3]
    reshaped = np.reshape(image_batch, (rows, cols, height, width, 3)).transpose(0, 2, 1, 3, 4).reshape(rows*height, cols * width, 3)
    plt.imshow(reshaped)
    plt.axis('off')
    plt.show()
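The reshape/transpose chain in display_images is easy to get wrong, so here is a small check of what it does, using a hypothetical make_grid helper that contains the same expression: image k lands in grid row k // cols and grid column k % cols.

```python
import numpy as np


def make_grid(image_batch, rows, cols):
    """Tile a (rows*cols, H, W, 3) batch into a single (rows*H, cols*W, 3) image."""
    height, width = image_batch.shape[1:3]
    return (np.reshape(image_batch, (rows, cols, height, width, 3))
            .transpose(0, 2, 1, 3, 4)  # -> (rows, H, cols, W, 3)
            .reshape(rows * height, cols * width, 3))


# Eight 2x2 "images", each filled with its own index
batch = np.stack([np.full((2, 2, 3), k, dtype=np.float32) for k in range(8)])
grid = make_grid(batch, rows=2, cols=4)
print(grid.shape)  # (4, 8, 3)
```

For example, image 5 sits in grid row 1, column 1, i.e. in pixels grid[2:4, 2:4].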

Finally, we can test our DataSequence by loading a batch of images and displaying it using our freshly defined display_images function:

import pandas as pd

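# The attribute file is whitespace-separated: skip its first line (the image count);
# since the header row has one field fewer than the data rows, pandas uses the
# image file names as the index, which we turn back into an Image_Name column
train_data_df = pd.read_csv("CelebA/list_attr_celeba.txt", sep=r"\s+", skiprows=1)
train_data_df = train_data_df.reset_index().rename(columns={"index": "Image_Name"})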
training_generator = DataSequence(train_data_df, "CelebA/img_align_celeba", batch_size=32)

(images, labels), labels = next(iter(training_generator))
display_images(images)

This should produce a figure with 32 faces. Since their intensities have been scaled to the [-1, 1] range, and matplotlib clips negative values to 0 when displaying float RGB images, they look dark:

We are now ready to define our model.

Model definition

The model we’ll use is a simple DCGAN with pixelwise normalization, spectral normalization and hinge loss. These additions help stabilize training, which is one of the main problems with GANs (i.e. they might start generating nonsense images as training proceeds).

Custom Layers

First we need to define custom Keras layers with spectral normalization:

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.python.keras.utils import conv_utils

# epsilon set according to BigGAN (https://arxiv.org/pdf/1809.11096.pdf)
def _l2normalizer(v, epsilon=1e-4):
    return v / (K.sum(v**2)**0.5 + epsilon)


def power_iteration(W, u, rounds=1):
    '''
    According to the spectral normalization paper, a single round of power
    iteration per training step is enough, since u is reused across steps.
    '''
    _u = u

    for i in range(rounds):
        _v = _l2normalizer(K.dot(_u, W))
        _u = _l2normalizer(K.dot(_v, K.transpose(W)))

    W_sn = K.sum(K.dot(_u, W) * _v)
    return W_sn, _u, _v

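The code above imports the required modules and defines two utility functions: an L2 normalizer and a power iteration routine that estimates the largest singular value of a weight matrix. Below we define the actual spectrally normalized 2D convolution layer: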

class SNConv2D(tf.keras.layers.Conv2D):
    def __init__(self, filters, spectral_normalization=True, **kwargs):
        self.spectral_normalization = spectral_normalization
        super(SNConv2D, self).__init__(filters, **kwargs)

    def build(self, input_shape):
        # Create the non-trainable vector u, which stores the running estimate
        # of the flattened kernel's first left singular vector
        self.u = self.add_weight(name='u', shape=(1, self.filters),
                                 initializer='uniform', trainable=False)
        # Let Conv2D create the kernel and bias
        super(SNConv2D, self).build(input_shape)

    def compute_spectral_normal(self, training=True):
        # Spectrally Normalized Weight
        if self.spectral_normalization:
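            # Get kernel tensor shape [kernel_h, kernel_w, in_channels, out_channels]
            W_shape = self.kernel.shape.as_list()

            # Flatten the kernel to [out_channels, kernel_h * kernel_w * in_channels],
            # moving the output-channel axis first so that each row holds one filter
            W_mat = K.reshape(K.permute_dimensions(self.kernel, (3, 0, 1, 2)),
                              [W_shape[-1], -1])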

            W_sn, u, v = power_iteration(W_mat, self.u)

            if training:
                # Update estimated 1st singular vector
                self.u.assign(u)

            return self.kernel / W_sn
        else:
            return self.kernel

    def call(self, inputs, training=None):

        outputs = K.conv2d(inputs,
                           self.compute_spectral_normal(training=training),
                           strides=self.strides, padding=self.padding,
                           data_format=self.data_format,
                           dilation_rate=self.dilation_rate)

        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias,
                                 data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        return super(SNConv2D, self).compute_output_shape(input_shape)

and a spectrally normalized 2D transposed convolution (also called deconvolution):

class SNConv2DTranspose(tf.keras.layers.Conv2DTranspose):
    def __init__(self, spectral_normalization=True, **kwargs):
        self.spectral_normalization = spectral_normalization
        super(SNConv2DTranspose, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create the non-trainable vector u, which stores the running estimate
        # of the flattened kernel's first left singular vector
        self.u = self.add_weight(name='u', shape=(1, self.filters),
                                 initializer='uniform', trainable=False)
        # Let Conv2DTranspose create the kernel and bias
        super(SNConv2DTranspose, self).build(input_shape)

    def compute_spectral_normal(self, training=True):
        # Spectrally Normalized Weight
        if self.spectral_normalization:
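            # Get kernel tensor shape [kernel_h, kernel_w, out_channels, in_channels]
            # (Conv2DTranspose stores the output channels in the third axis)
            W_shape = self.kernel.shape.as_list()

            # Flatten the kernel to [out_channels, kernel_h * kernel_w * in_channels]
            W_mat = K.reshape(K.permute_dimensions(self.kernel, (2, 0, 1, 3)),
                              [W_shape[-2], -1])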

            W_sn, u, v = power_iteration(W_mat, self.u)

            if training:
                # Update estimated 1st singular vector
                self.u.assign(u)

            return self.kernel / W_sn
        else:
            return self.kernel

    def call(self, inputs, training=None):
        input_shape = K.shape(inputs)
        batch_size = input_shape[0]
        if self.data_format == 'channels_first':
            h_axis, w_axis = 2, 3
        else:
            h_axis, w_axis = 1, 2

        height, width = input_shape[h_axis], input_shape[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides
        out_pad_h = out_pad_w = None

        # Infer the dynamic output shape:
        out_height = conv_utils.deconv_output_length(
            height, kernel_h, self.padding, stride=stride_h)
        out_width = conv_utils.deconv_output_length(
            width, kernel_w, self.padding, stride=stride_w)
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)

        outputs = K.conv2d_transpose(
            inputs,
            self.compute_spectral_normal(training=training),
            output_shape,
            self.strides,
            padding=self.padding,
            data_format=self.data_format
        )

        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias,
                                 data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        return super(SNConv2DTranspose, self).compute_output_shape(input_shape)

Don’t be put off by the amount of code: these classes are just wrappers around Keras’s existing 2D convolution and transposed convolution layers, which divide the kernel by its estimated spectral norm before using it.
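The quantity these layers divide the kernel by is the largest singular value of the flattened kernel, estimated by power_iteration. The NumPy sketch below repeats the same iteration (for many rounds at once, whereas the layers run a single round per training step and rely on u being carried over between steps) and compares the result with np.linalg.svd. The function names and the random test matrix are our own choices.

```python
import numpy as np


def l2normalize(v, epsilon=1e-4):
    """NumPy counterpart of _l2normalizer."""
    return v / (np.sqrt(np.sum(v ** 2)) + epsilon)


def spectral_norm_estimate(W, rounds=200, seed=0):
    """Estimate the largest singular value of W (shape [out, N]) by power iteration."""
    u = np.random.default_rng(seed).normal(size=(1, W.shape[0]))
    for _ in range(rounds):
        v = l2normalize(u @ W)         # (1, N)
        u = l2normalize(v @ W.T)       # (1, out)
    return float(np.sum((u @ W) * v))  # u W v^T, as in power_iteration


W = np.random.default_rng(1).normal(size=(16, 27))
print(spectral_norm_estimate(W), np.linalg.svd(W, compute_uv=False)[0])
```

After dividing W by this estimate, its largest singular value is close to 1, which is what limits how sharply each layer can change its output.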

Next, we’ll define the PixelNormalization layer, introduced in Progressive GAN (ProGAN) as an additional source of stabilization:

class PixelNormalization(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(PixelNormalization, self).__init__(**kwargs)

    def call(self, inputs, training=None):
        # training is unused: the layer behaves the same at train and test time
        # Calculate square pixel values
        values = inputs**2.0
        # Calculate the mean of the squared values across channels
        mean_values = tf.keras.backend.mean(values, axis=-1, keepdims=True)
        # Ensure the mean is not zero
        mean_values += 1.0e-8
        # Calculate the sqrt of the mean squared value (root mean square)
        l2 = tf.keras.backend.sqrt(mean_values)
        # Normalize values by the root mean square
        normalized = inputs / l2
        return normalized

    def compute_output_shape(self, input_shape):
        return input_shape
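Pixel normalization rescales the feature vector at every spatial location so that its mean squared value across channels is (up to the 1e-8 epsilon) equal to 1. A NumPy equivalent of PixelNormalization.call, with a check of that property (pixel_norm is our own name for the sketch):

```python
import numpy as np


def pixel_norm(x, epsilon=1.0e-8):
    """NumPy equivalent of PixelNormalization.call for a (batch, H, W, C) array."""
    mean_sq = np.mean(x ** 2, axis=-1, keepdims=True) + epsilon
    return x / np.sqrt(mean_sq)


x = np.random.default_rng(0).normal(scale=5.0, size=(2, 4, 4, 8))
y = pixel_norm(x)
print(np.mean(y ** 2, axis=-1)[0, 0])  # close to 1.0, and likewise at every other pixel
```

Unlike batch normalization, this has no learned parameters and no dependence on other samples in the batch.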

Models

We can finally define our GAN’s components: the generator, the discriminator and a wrapper model.

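The generator and discriminator follow a standard DCGAN architecture, with spectrally normalized convolutions, and with pixelwise normalization layers in the generator instead of batch normalization. The generator also upsamples with UpSampling2D followed by a convolution, rather than with transposed convolutions.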
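As a quick sanity check of the architecture, the spatial sizes can be traced by hand: in the generator, each UpSampling2D doubles height and width while the stride-1 "same" convolutions keep them unchanged; in the discriminator, each stride-2 "same" convolution produces ceil(n / 2) outputs. A small sketch of that arithmetic (the helper names are hypothetical):

```python
import math


def generator_sizes(start=4, n_upsamplings=4):
    """Spatial size after the dense/reshape step and after each upsampling block."""
    sizes = [start]
    for _ in range(n_upsamplings):
        sizes.append(sizes[-1] * 2)  # UpSampling2D doubles; the 'same' conv keeps the size
    return sizes


def discriminator_sizes(start=64, n_convs=4, stride=2):
    """Spatial size of the input image and after each stride-2 'same' convolution."""
    sizes = [start]
    for _ in range(n_convs):
        sizes.append(math.ceil(sizes[-1] / stride))
    return sizes


print(generator_sizes())      # [4, 8, 16, 32, 64]
print(discriminator_sizes())  # [64, 32, 16, 8, 4]
```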

Our generator takes an additional input vector with the image conditions (the 40 attributes), which is concatenated to the latent noise:

class Generator(tf.keras.Model):
    def __init__(self, latent_dim, filters=64, kernel_size=4, cond_dim=40):
        super(Generator, self).__init__()

        self.filters = filters
        self.latent_dim = latent_dim

        self.dense1 = tf.keras.layers.Dense(4 * 4 * filters * 8, input_shape=(latent_dim + cond_dim,))
        self.relu1 = tf.keras.layers.ReLU()
        self.reshape1 = tf.keras.layers.Reshape((4, 4, filters * 8))
        self.bn1 = PixelNormalization()
        # 4x4 -> 8x8
        self.block1_upscale = tf.keras.layers.UpSampling2D()
        self.block1_conv1 = SNConv2D(
            filters=filters * 4, kernel_size=kernel_size, strides=(1, 1), 
            padding="same")
        self.block1_bn = PixelNormalization()
        self.block1_relu1 = tf.keras.layers.ReLU()
        # 8x8 -> 16x16
        self.block2_upscale = tf.keras.layers.UpSampling2D()
        self.block2_conv1 = SNConv2D(
            filters=filters * 2, kernel_size=kernel_size, strides=(1, 1), 
            padding="same")
        self.block2_bn = PixelNormalization()
        self.block2_relu1 = tf.keras.layers.ReLU()
        # 16x16 -> 32x32
        self.block3_upscale = tf.keras.layers.UpSampling2D()
        self.block3_conv1 = SNConv2D(
            filters=filters, kernel_size=kernel_size, strides=(1, 1), 
            padding="same")
        self.block3_bn = PixelNormalization()
        self.block3_relu1 = tf.keras.layers.ReLU()
        # 32x32 -> 64x64
        self.block4_upscale = tf.keras.layers.UpSampling2D()
        self.block4_conv1 = SNConv2D(
            filters=filters // 2, kernel_size=kernel_size, strides=(1, 1), 
            padding="same")
        self.block4_bn = PixelNormalization()
        self.block4_relu1 = tf.keras.layers.ReLU()
        # 64 x 64 x FILTERS -> 64 x 64 x 3
        self.block5_conv1 = SNConv2D(
            filters=3, kernel_size=4, strides=(1, 1), 
            padding="same", activation="tanh")

    @tf.function
    def call(self, z, conditions, training=False):

        x = tf.concat([z, conditions], axis=-1)

        x = self.dense1(x)
        x = self.relu1(x)
        x = self.reshape1(x)
        x = self.bn1(x, training=training)

        x = self.block1_upscale(x)
        x = self.block1_conv1(x)
        x = self.block1_bn(x, training=training)
        x = self.block1_relu1(x)

        x = self.block2_upscale(x)
        x = self.block2_conv1(x)
        x = self.block2_bn(x, training=training)
        x = self.block2_relu1(x)

        x = self.block3_upscale(x)
        x = self.block3_conv1(x)
        x = self.block3_bn(x, training=training)
        x = self.block3_relu1(x)

        x = self.block4_upscale(x)
        x = self.block4_conv1(x)
        x = self.block4_bn(x, training=training)
        x = self.block4_relu1(x)

        images = self.block5_conv1(x)
        return images

The discriminator outputs both a real/fake score and a classification score for each of the 40 CelebA attributes:

class Discriminator(tf.keras.Model):
    def __init__(self, filters=64, kernel_size=3, n_classes=40):
        super(Discriminator, self).__init__()

        self.filters = filters

        # 64 x 64 x 3 -> 32 x 32 x FILTERS
        self.block1_conv1 = SNConv2D(filters, kernel_size=kernel_size, strides=(2, 2), padding="same")
        self.block1_lrelu1 = tf.keras.layers.LeakyReLU(alpha=0.1)

        # 32 x 32 x FILTERS -> 16 x 16 x (2 * FILTERS)
        self.block2_conv1 = SNConv2D(filters * 2, kernel_size=kernel_size, strides=(2, 2), padding="same")
        self.block2_lrelu1 = tf.keras.layers.LeakyReLU(alpha=0.1)

        # 16 x 16 x (2 * FILTERS) -> 8 x 8 x (4 * FILTERS)
        self.block3_conv1 = SNConv2D(filters * 4, kernel_size=kernel_size, strides=(2, 2), padding="same")
        self.block3_lrelu1 = tf.keras.layers.LeakyReLU(alpha=0.1)

        # 8 x 8 x (4 * FILTERS) -> 4 x 4 x (8 * FILTERS)
        self.block4_conv1 = SNConv2D(filters * 8, kernel_size=kernel_size, strides=(2, 2), padding="same")
        self.block4_lrelu1 = tf.keras.layers.LeakyReLU(alpha=0.1)

        # Current size: 4 x 4 x (8 * FILTERS)
        self.avg_pool = tf.keras.layers.GlobalAvgPool2D()

        self.scoring = tf.keras.layers.Dense(1)
        self.classifier = tf.keras.layers.Dense(n_classes, activation="sigmoid")
    
    @tf.function
    def call(self, images, training=False):
        x = images

        x = self.block1_conv1(x)
        x = self.block1_lrelu1(x)

        x = self.block2_conv1(x)
        x = self.block2_lrelu1(x)

        x = self.block3_conv1(x)
        x = self.block3_lrelu1(x)

        x = self.block4_conv1(x)
        x = self.block4_lrelu1(x)

        x = self.avg_pool(x)

        scores = self.scoring(x)
        classes = self.classifier(x)

        return scores, classes

The wrapper model overrides train_step, so that the GAN can be trained with Keras’s fit method, as shown in the DCGAN Keras example:

class GAN_Wrapper(tf.keras.Model):

    def __init__(self, discriminator, generator, **kwargs):
        super(GAN_Wrapper, self).__init__(**kwargs)

        self.discriminator = discriminator
        self.generator = generator

        self.latent_dim = self.generator.latent_dim

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(GAN_Wrapper, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_classifier_loss = tf.keras.losses.BinaryCrossentropy()
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def call(self, x, training=False):
        """
        This method is overridden only because it is required by tf.keras.Model
        """
        pass

    @tf.function
    def train_step(self, data):
        # Unpack images and their labels
        (real_images, conditions), _ = data

        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]

        # Train the discriminator
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images from the latent vector
            fake_images = self.generator(random_latent_vectors, conditions, training=False)
            # Get the logits for the fake images
            fake_logits, fake_classes = self.discriminator(fake_images, training=True)
            # Get the logits for real images
            real_logits, real_classes = self.discriminator(real_images, training=True)
            # Calculate discriminator loss using fake and real logits
            d_cost = self.d_loss_fn(logits_real=real_logits, logits_fake=fake_logits)
            d_class_real = self.d_classifier_loss(y_true=conditions, y_pred=real_classes)
            # Add the class loss to the original discriminator loss
            d_loss = d_cost + d_class_real

        # Get the gradients w.r.t the discriminator loss
        d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
        # Update the weights of the discriminator using the discriminator optimizer
        self.d_optimizer.apply_gradients(
            zip(d_gradient, self.discriminator.trainable_variables)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Train the generator
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, conditions, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits, gen_img_classes = self.discriminator(generated_images, training=False)
            # Calculate the generator loss
            g_cost = self.g_loss_fn(gen_img_logits)
            g_class_gen = self.d_classifier_loss(y_true=conditions, y_pred=gen_img_classes)
            g_loss = g_cost + g_class_gen
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        return {
            "d_loss": d_loss, 
            "g_loss": g_loss, 
            "D(G(z))": tf.reduce_mean(gen_img_logits), 
            "D(x)": tf.reduce_mean(real_logits)
        }
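The auxiliary classification terms in train_step treat the 40 attributes as independent yes/no problems: tf.keras.losses.BinaryCrossentropy averages the per-attribute cross-entropy over the attributes and then over the batch. A NumPy sketch of that computation (multilabel_bce is our own name, and it ignores the small clipping epsilon Keras applies to the predicted probabilities):

```python
import numpy as np


def multilabel_bce(y_true, y_pred):
    """Mean binary cross-entropy over attributes, then over the batch."""
    per_attribute = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return np.mean(np.mean(per_attribute, axis=-1))


y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
uninformed = np.array([[0.5, 0.5], [0.5, 0.5]])
print(multilabel_bce(y_true, uninformed))  # log(2), about 0.693
```

For the generator, the same loss rewards generated images whose predicted attributes match the conditions they were generated from.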

Having defined our model and image data loader, we can finally proceed to train it.

Training

First, we declare our training parameters:

# Fit parameters
batch_size = 32
n_epochs = 30
# Generator parameters
latent_dim = 128
conditional_dim = 40
filters_gen = 64
kernel_size_gen = 3
# Discriminator parameters
filters_disc = 64
kernel_size_disc = 3

Then, we instantiate Discriminator, Generator and GAN_Wrapper objects:

# Create models
generator_model = Generator(latent_dim, filters=filters_gen, kernel_size=kernel_size_gen)
discriminator_model = Discriminator(filters=filters_disc, kernel_size=kernel_size_disc)
# Create gan wrapper model
gan_model = GAN_Wrapper(discriminator_model, generator_model)

Next, we define the hinge loss for the discriminator and the generator:

@tf.function
def discriminator_loss(logits_real, logits_fake):
    real_loss = tf.keras.backend.mean(tf.keras.backend.relu(1 - logits_real))
    fake_loss = tf.keras.backend.mean(tf.keras.backend.relu(1 + logits_fake))
    return (fake_loss + real_loss) / 2

@tf.function
def generator_loss(logits_fake):
    return -tf.keras.backend.mean(logits_fake)

and compile the wrapper model, with a separate Adam optimizer for each network:

gan_model.compile(
    d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.00005, beta_1=0.0),
    g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.0),
    d_loss_fn=discriminator_loss,
    g_loss_fn=generator_loss,
)
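With the hinge loss defined above, the discriminator is penalized only when a real sample scores below +1 or a generated one scores above -1, while the generator simply tries to raise the discriminator's score on its samples. A NumPy transcription of the two loss functions, with a small worked example (the logit values are arbitrary):

```python
import numpy as np


def d_hinge(logits_real, logits_fake):
    """NumPy version of discriminator_loss."""
    real_loss = np.mean(np.maximum(0.0, 1.0 - logits_real))
    fake_loss = np.mean(np.maximum(0.0, 1.0 + logits_fake))
    return (fake_loss + real_loss) / 2


def g_hinge(logits_fake):
    """NumPy version of generator_loss."""
    return -np.mean(logits_fake)


real = np.array([2.0, 0.5])   # the first real sample is beyond the margin, the second is not
fake = np.array([-2.0, 0.5])  # the first fake sample is beyond the margin, the second is not
print(d_hinge(real, fake))  # 0.5
print(g_hinge(fake))        # 0.75
```

Here the real part contributes mean(0, 0.5) = 0.25 and the fake part mean(0, 1.5) = 0.75, averaged to 0.5.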

Finally, we can perform the actual fitting of our GAN:

history = gan_model.fit(training_generator,
    use_multiprocessing=True,
    workers=8,
    epochs=n_epochs,
    callbacks=[
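        # Subclassed models cannot be saved whole in HDF5 format, so save weights only
        tf.keras.callbacks.ModelCheckpoint(filepath='saved_models/model_{epoch}.h5', save_weights_only=True)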
    ]
)

The code above trains a GAN to generate human faces with specific visual attributes.

Notice that we used the ModelCheckpoint callback to save our model at every epoch. If we also wanted to save sample images, we could add the following callback, which generates images for a given set of conditions, to the callbacks list:

ImagesLoggingCallback(generator_model, latent_dim, labels, "images")

Where ImagesLoggingCallback is defined as:

class ImagesLoggingCallback(tf.keras.callbacks.Callback):

    def __init__(self, generator, latent_dim, view_conditions, images_dir, n_images=32, rows=4, cols=8):
        super(ImagesLoggingCallback, self).__init__()
        # Note: we can access self.model.generator also
        self.generator = generator
        self.n_images = n_images
        self.rows = rows
        self.cols = cols
        self.latent_dim = latent_dim
        self.images_dir = images_dir
        os.makedirs(self.images_dir, exist_ok=True)
        self.view_conditions = view_conditions
        self.random_latent_vectors = tf.random.normal(shape=(self.n_images, self.latent_dim))

    def on_epoch_begin(self, epoch, logs=None):
        # Generate images from fixed latent vectors and conditions
        generated_images = self.generator(self.random_latent_vectors, self.view_conditions, training=False)
        # Rescale from [-1, 1] to [0, 1] and convert to a NumPy array
        generated_images = ((generated_images + 1) / 2.0).numpy()
        # Tile the images into a grid and save the figure
        fig = plt.figure()
        height, width = generated_images.shape[1:3]
        reshaped = np.reshape(generated_images, (self.rows, self.cols, height, width, 3)).transpose(0, 2, 1, 3, 4).reshape(self.rows*height, self.cols * width, 3)
        plt.imshow(reshaped)
        plt.axis('off')
        fig.savefig(os.path.join(self.images_dir, "sample_{}.png".format(epoch)))
        plt.close(fig)

Results

The overall training should take around an hour on a GTX1080 GPU, and produce results similar to the following:

We can play with the conditioning labels, for instance to make every generated sample blonde:

And to make everyone smile:
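In code, this only requires editing the condition vectors before passing them to the generator. A minimal sketch, assuming the conditions are a NumPy array of 0/1 values whose columns follow DataSequence.label_columns; the helper set_attributes and the three-column toy list are ours, and we do not know exactly which other attributes (if any) were switched off to produce the figures above.

```python
import numpy as np


def set_attributes(conditions, label_columns, updates):
    """Return a copy of `conditions` with the named attribute columns overwritten."""
    edited = np.array(conditions, dtype=np.float32, copy=True)
    for name, value in updates.items():
        edited[:, label_columns.index(name)] = value
    return edited


columns = ["Black_Hair", "Blond_Hair", "Smiling"]  # toy subset of the 40 attributes
conditions = np.array([[1, 0, 0], [0, 0, 1]], dtype=np.float32)
blond_smiling = set_attributes(conditions, columns,
                               {"Black_Hair": 0.0, "Blond_Hair": 1.0, "Smiling": 1.0})
print(blond_smiling)  # every row becomes [0. 1. 1.]
```

With the real model, the edited array (built from training_generator.label_columns and, e.g., the labels batch from the preview) would be passed as the conditions argument of generator_model, together with fixed latent vectors, so that only the attributes change between figures.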

Improvements

As seen in the images above, our model generates faces that resemble human ones; however, most of them look unrealistic.

One possible solution is simply to train the model longer, as we trained for only 30 epochs.

Another simple solution is the one we describe next.

EMA

Instead of training a single generator, whose gradient updates might still cause some instability (producing less smooth images), we can keep a second generator whose weights are updated with an exponential moving average (EMA) of the first generator’s weights, rather than with gradient methods.

More precisely, let G be the generator model with weights u, trained with gradient descent. Our training procedure will:

  1. Clone the generator model, producing G’ with weights w.
  2. Periodically update the weights of G’ using EMA: $w_{t+1} = \beta w_t + (1 - \beta) u_t$

The value of $\beta$ (beta in the code) is chosen close to 1, so that $1 - \beta$ is close to zero, for example $\beta = 0.9999$.
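To get a feel for how slowly an EMA copy follows the trained weights, consider a single scalar weight whose trained value jumps to 1 while the EMA copy starts at 0. After n updates the copy equals $1 - \beta^n$, so it takes on the order of $1/(1 - \beta)$ updates (10,000 for $\beta = 0.9999$) to catch up. A small sketch with $\beta = 0.9$ so that the numbers stay readable (ema_track is our own helper name):

```python
def ema_track(target, start, beta, n_updates):
    """Apply w <- beta * w + (1 - beta) * target repeatedly and return the history."""
    history = [start]
    w = start
    for _ in range(n_updates):
        w = beta * w + (1 - beta) * target
        history.append(w)
    return history


history = ema_track(target=1.0, start=0.0, beta=0.9, n_updates=10)
print(round(history[-1], 4))  # 0.6513, i.e. 1 - 0.9**10
```

This slow response is what smooths out the step-to-step jitter of the gradient-trained generator.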

Let us define an EMA update function, which takes two models as input and updates the second one by blending in a small fraction of the first one’s weights:

def ema_update(model, model_ema, beta=0.9999):
    """
    Performs a model update by using exponential moving average (EMA) 
    of first model's weights to update second model's weights 
    as defined in "The Unusual Effectiveness of Averaging in GAN Training" (https://arxiv.org/abs/1806.04498), 
    realizing the following update:
    
    model_ema.weights = beta * model_ema.weights + (1 - beta) * model.weights
    
    :param model: original, gradient descent trained model.
    :param model_ema: clone of original model, that will get updated using EMA.
    :param beta: EMA update weight, determines what portion of model_ema weights should be kept (default=0.9999).
    """
    # for each model layer index
    for i in range(len(model.layers)):
        updating_weights = model.layers[i].get_weights() # original model's weights
        ema_old_weights = model_ema.layers[i].get_weights() # ema model's weights
        ema_new_weights = [] # ema model's update weights
        if len(updating_weights) != len(ema_old_weights):
            # weight list length mismatch between model's weights and ema model's weights
            print("Different weight length")
            # copy ema weights directly from the model's weights
            ema_new_weights = updating_weights
        else:
            # for each weight tensor of original model's weights list
            for j in range(len(updating_weights)):
                n_weight = beta * ema_old_weights[j] + (1 - beta) * updating_weights[j]
                ema_new_weights.append(n_weight)
        # update weights
        model_ema.layers[i].set_weights(ema_new_weights)

Then we define the wrapper for our EMA model, as a subclass of the original GAN_Wrapper:

class GAN_WrapperEMA(GAN_Wrapper):

    def __init__(self, discriminator, generator, generator_ema, beta=0.9999):
        super(GAN_WrapperEMA, self).__init__(discriminator, generator)
        self.generator_ema = generator_ema
        self.generator_ema.set_weights(self.generator.get_weights())

        self.beta = beta

        self.current_step = tf.Variable(initial_value=0, trainable=False)

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(GAN_WrapperEMA, self).compile(d_optimizer, g_optimizer, d_loss_fn, g_loss_fn)

    def ema_step(self):
        ema_update(self.generator, self.generator_ema, beta=self.beta)

    def call(self, x, training=False):
        """
        This method is overridden only because it is required by tf.keras.Model
        """
        pass

    @tf.function
    def train_step(self, data):
        stats_dict = super(GAN_WrapperEMA, self).train_step(data)
        self.current_step.assign_add(1)
        return stats_dict

Now we need to instantiate all of the objects again:

# EMA params
ema_beta = 0.9999
ema_start = 10
ema_every = 2

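# create models
generator_model = Generator(latent_dim, filters=filters_gen, kernel_size=kernel_size_gen)
generator_model_ema = Generator(latent_dim, filters=filters_gen, kernel_size=kernel_size_gen)
discriminator_model = Discriminator(filters=filters_disc, kernel_size=kernel_size_disc)
# build both generators with a dummy call, so that their weights exist before
# GAN_WrapperEMA copies generator_model's weights into generator_model_ema
dummy_z = tf.zeros((1, latent_dim))
dummy_conditions = tf.zeros((1, conditional_dim))
generator_model(dummy_z, dummy_conditions)
generator_model_ema(dummy_z, dummy_conditions)
# create gan wrapper model
gan_model = GAN_WrapperEMA(discriminator_model, generator_model, generator_model_ema, beta=ema_beta)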
# images callback
image_save_callback = ImagesLoggingCallback(generator_model_ema, latent_dim, labels, "images")
# compile model
gan_model.compile(
    d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.00005, beta_1=0.0),
    g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.0),
    d_loss_fn=discriminator_loss,
    g_loss_fn=generator_loss,
)

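As we can see from the code, GAN_WrapperEMA initially sets generator_model_ema’s weights to generator_model’s weights. Note that this copy only has an effect if both generators have already been built (e.g. by calling each of them once on dummy inputs), since a subclassed Keras model has no weights until it is first called.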

The next step is needed because train_step is compiled with @tf.function to speed up its execution, and inside a compiled function we cannot read and write weights through get_weights and set_weights, as ema_update does. So we call ema_update outside train_step, by implementing a custom training loop.

To implement a custom training loop, we’ll first define two utility functions that help us keep track of the model’s statistics:

def update_dict(orig_dict, new_dict):
    for key in orig_dict.keys():
        orig_dict[key].append(new_dict[key])
    return orig_dict

def print_stats(epoch, step, stats_dict):
    print("Epoch: {}, Step: {}, {}".format(epoch, step, {key: tf.reduce_mean(value).numpy() for key, value in stats_dict.items()}))

Finally, we implement the loop itself, which executes train_step, updating the original generator_model with a gradient optimizer. Additionally, every ema_every steps (once the first ema_start steps have passed), we perform an ema_step, which updates generator_model_ema using EMA.

# Print params
print_output_every = 400
# Make sure the output folder exists
os.makedirs("saved_models", exist_ok=True)
# Train the model
for epoch in range(n_epochs):
    image_save_callback.on_epoch_begin(epoch)
    stats_dict = {"d_loss": [], "g_loss": [], "D(G(z))": [], "D(x)": []}
    for step, x in enumerate(training_generator):
        train_details = gan_model.train_step(x)
        update_dict(stats_dict, train_details)
        if gan_model.current_step > ema_start and gan_model.current_step % ema_every == 0:
            gan_model.ema_step()
        if step % print_output_every == 0:
            print_stats(epoch, step, stats_dict)
    # Iterating over a Sequence does not call on_epoch_end, so reshuffle explicitly
    training_generator.on_epoch_end()
    # Note: callbacks such as ModelCheckpoint are only invoked by fit(),
    # so in this custom loop we save the weights manually.
    gan_model.save_weights("saved_models/model_{}.h5".format(epoch))
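The condition inside the loop above decides on which steps the EMA update happens. Since train_step increments current_step before the check, the first batch already sees current_step == 1. A plain-Python sketch of the resulting schedule (ema_update_steps is a hypothetical helper):

```python
def ema_update_steps(n_steps, ema_start=10, ema_every=2):
    """Return the values of current_step at which ema_step() would be called."""
    steps = []
    for current_step in range(1, n_steps + 1):  # incremented inside train_step before the check
        if current_step > ema_start and current_step % ema_every == 0:
            steps.append(current_step)
    return steps


print(ema_update_steps(16))  # [12, 14, 16]
```

With ema_every=2, generator_model_ema receives one EMA update for every two gradient steps of generator_model.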

Running our code for 30 epochs, as before, should produce better-looking results, similar to the ones below:

Let us also show attribute manipulation with the same conditions as before.

Base:

Blonde:

Blonde and smiling:
