How to implement GAN with Keras

Posted on May 28, 2017

This tutorial guides you through implementing a GAN with Keras. The complete code can be accessed in my github repository. If you are not familiar with GANs, please check the first part of this post or another blog to get the gist of GANs.

Generator

The generator is used to generate images from noise. Following DCGAN, one feasible generator in GAN/models/gen.py looks like:

from keras.layers import (Input, Dense, BatchNormalization, Activation, Reshape,
                          UpSampling2D, Conv2D, Conv2DTranspose as Deconv2D)
from keras.models import Model


def basic_gen(input_shape, img_shape, nf=128, scale=4, FC=[], use_upsample=False):
    # img_shape is channels-first: (channels, height, width); this assumes Keras
    # is configured with image_data_format='channels_first'
    dim, h, w = img_shape

    img = Input(input_shape)
    x = img
    for fc_dim in FC:
        x = Dense(fc_dim)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # project the noise to an initial (nf*2**(scale-1), h/2**scale, w/2**scale) feature map
    x = Dense(nf*2**(scale-1) * (h//2**scale) * (w//2**scale))(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Reshape((nf*2**(scale-1), h//2**scale, w//2**scale))(x)

    for s in range(scale-2, -1, -1):
        # upsample + conv can eliminate the checkerboard artifact
        # http://distill.pub/2016/deconv-checkerboard/
        if use_upsample:
            x = UpSampling2D()(x)
            x = Conv2D(nf*2**s, (3, 3), padding='same')(x)
        else:
            x = Deconv2D(nf*2**s, (3, 3), strides=(2, 2), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    if use_upsample:
        x = UpSampling2D()(x)
        x = Conv2D(dim, (3, 3), padding='same')(x)
    else:
        x = Deconv2D(dim, (3, 3), strides=(2, 2), padding='same')(x)

    x = Activation('tanh')(x)

    return Model(img, x)

  • scale is the number of 2x upsampling steps, so the output image is 2**scale times larger (per side) than the initial feature map.
  • FC lists the sizes of the fully connected layers applied before the upsampling part.
  • use_upsample is an option to replace deconvolution with upsampling + convolution, which gives better image quality. Please check this post for details.
  • I use tanh as the output activation, which also means real images should be normalized to [-1, 1]; a short usage sketch follows.
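
With the default arguments, a generator for 64x64 RGB images driven by 100-dimensional noise could be built roughly like this (the shapes and the FC size are my own illustration, not fixed by the repo):

# usage sketch (illustrative shapes): 100-d noise -> 64x64 RGB image,
# channels_first layout as assumed by basic_gen above
gen = basic_gen(input_shape=(100,), img_shape=(3, 64, 64),
                nf=128, scale=4, FC=[1024])
gen.summary()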

Discriminator

The discriminator is simpler than the generator. In GAN/models/dis.py I implemented a very simple one:

from keras.layers import Input, Dense, Conv2D, Flatten, BatchNormalization, LeakyReLU
from keras.models import Model


def basic_dis(input_shape, nf=128, scale=4, FC=[], bn=True):
    # input_shape is channels-first: (channels, height, width)
    dim, h, w = input_shape

    img = Input(input_shape)
    x = img

    for s in range(scale):
        x = Conv2D(nf*2**s, (5, 5), strides=(2, 2), padding='same')(x)
        if bn: x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    x = Flatten()(x)
    for fc in FC:
        x = Dense(fc)(x)
        if bn: x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
    x = Dense(1, activation='sigmoid')(x)

    return Model(img, x)
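
A matching usage sketch for the discriminator (again, the shapes are illustrative):

# usage sketch (illustrative shapes): 64x64 RGB image -> probability of being real
dis = basic_dis(input_shape=(3, 64, 64), nf=128, scale=4)
dis.summary()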

GAN

The code for the GAN itself is in GAN/models/GAN.py.

To train the generator, we first have to connect it to the discriminator:

gendis = gan = Sequential([generator, discriminator])

For simplicity, we use gen to refer to the generator and dis to refer to the discriminator. During the training of gen, the weights of dis should be fixed. We can enforce that by setting the trainable flag before compiling.

dis.trainable = False 
gendis.compile(optimizer=opt, loss='binary_crossentropy')
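
In Keras the trainable flag is baked in when a model is compiled, so gendis keeps dis frozen even though we re-enable dis.trainable later for the discriminator trainer. A quick sanity check might look like this (illustrative only, assuming the 100-d noise generator from the sketch above; nbatch is an assumed value):

# illustrative sanity check: dis's conv kernels should not change when training
# gendis, because dis.trainable was False when gendis was compiled
import numpy as np

nbatch = 16                                        # assumed batch size
Z = np.random.uniform(-1., 1., size=(nbatch, 100)).astype('float32')

w_before = dis.layers[1].get_weights()[0].copy()   # first Conv2D kernel of dis
gendis.train_on_batch(Z, np.ones((nbatch, 1)))
assert np.array_equal(w_before, dis.layers[1].get_weights()[0])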

To train the discriminator, we should keep real and fake images in separate batches, as suggested in this blog. If you mix the two types of images in one batch, the discrepancy between their feature statistics in the batch normalization layers can be so significant that discriminating them becomes trivial, and nothing useful is learned.

In my code, I use a two-branch model to forward the two types of images separately:

shape = dis.get_input_shape_at(0)[1:]
gen_input, real_input = Input(shape), Input(shape)
# two branches of the same dis: one sees generated images, the other real ones
dis2batch = Model([gen_input, real_input], [dis(gen_input), dis(real_input)])

dis.trainable = True
dis2batch.compile(optimizer=opt, loss='binary_crossentropy', metrics=['binary_accuracy'])

To sum up:

gen_trainner = gendis
dis_trainner = dis2batch
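
Putting the pieces together, the whole build step could be wrapped in a helper roughly like this (the name build_gan and the Adam settings are my own choices, not taken from the repo):

# a minimal sketch that assembles the two trainers described above
from keras.layers import Input
from keras.models import Model, Sequential
from keras.optimizers import Adam


def build_gan(gen, dis, opt=None):
    opt = opt or Adam(2e-4, beta_1=0.5)   # assumed optimizer settings

    # generator trainer: gen followed by a frozen dis
    gendis = Sequential([gen, dis])
    dis.trainable = False
    gendis.compile(optimizer=opt, loss='binary_crossentropy')

    # discriminator trainer: two branches sharing the same dis
    shape = dis.get_input_shape_at(0)[1:]
    gen_input, real_input = Input(shape), Input(shape)
    dis2batch = Model([gen_input, real_input],
                      [dis(gen_input), dis(real_input)])
    dis.trainable = True
    dis2batch.compile(optimizer=opt, loss='binary_crossentropy',
                      metrics=['binary_accuracy'])

    gen_trainner, dis_trainner = gendis, dis2batch
    return gen_trainner, dis_trainner

The order matters here: gendis must be compiled while dis.trainable is False, and dis2batch while it is True.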

Training

The simplest code to train a GAN might look like this:

# this loop lives inside the GAN class: self.coding is the noise dimension and
# self.generate(Z) runs the generator on a batch of noise vectors
for iteration in range(1, niter+1):
    print('iteration', iteration)
    real_img = data_generator(nbatch)
    Z = np.random.uniform(-1., 1., size=(nbatch, self.coding)).astype('float32')
    gen_img = self.generate(Z)

    # train gen: it tries to make dis label its fake images as real (1)
    y = np.ones((nbatch, 1))
    g_loss = self.gen_trainner.train_on_batch(Z, y)

    # train dis: generated images are labeled 0, real images 1
    gen_y   = np.zeros((nbatch, 1))
    real_y  = np.ones((nbatch, 1))
    d_loss = self.dis_trainner.train_on_batch([gen_img, real_img], [gen_y, real_y])
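
Here data_generator is whatever function feeds batches of real images. A hypothetical version (not from the repo) that matches the generator's tanh output and channels-first layout could be:

# hypothetical data_generator for illustration only (not from the repo):
# returns a random batch of real images normalized to [-1, 1], channels_first
import numpy as np


def make_data_generator(X):
    # X: uint8 array of shape (N, channels, height, width), pixel values in [0, 255]
    X = X.astype('float32') / 127.5 - 1.

    def data_generator(nbatch):
        idx = np.random.choice(len(X), size=nbatch, replace=False)
        return X[idx]

    return data_generator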

As suggested by many works, the discriminator should be trained more often than the generator. To do that, we can modify the loop:

k = 3 # a constant: how many times we train dis before training gen
for iteration in range(1, niter+1):
    ...
    if iteration % (k+1) == 0:
        # train gen once every k+1 iterations
        g_loss = self.gen_trainner.train_on_batch(Z, y)
    else:
        # train dis on the remaining k iterations
        d_loss = self.dis_trainner.train_on_batch([gen_img, real_img], [gen_y, real_y])

Also, to improve the stability of training, we can collect recently generated (fake) images in a pool and sample the discriminator's fake inputs from it:

fake_pool = []
for iteration in range(1, niter+1):
    ...

    # previous version:
    #     gen_img = self.generate(Z)

    # current version: keep the last pool_size fakes and sample a batch from them
    fake_pool.extend(self.generate(Z))
    fake_pool = fake_pool[-pool_size:]
    gen_img = np.array(fake_pool)[np.random.choice(len(fake_pool), size=(nbatch,), replace=False)]

    ...

Postscript:

Other training tricks can be found here.

Based on this framework, I also implemented Conditional GAN, InfoGAN, and other GAN variants with the Keras-1.x API in legacy/. Although that code is not maintained and might not even work with the Keras-2.x API, you can check it for inspiration.

I used Theano before Keras, and building a deep neural network with raw Theano, or even with Tensorflow, was taxing. I implemented some Keras-like libraries of my own, but they were not as well organized or forward-looking as Keras, so I turned to the Keras community instead.
