# How to implement GAN with Keras

Posted on May 28, 2017

This tutorial shows how to implement a GAN with Keras. The complete code can be accessed in my GitHub repository. If you are not familiar with GANs, please check the first part of this post or another blog to get the gist of them.

### Generator

The generator produces images from noise. Following DCGAN, one feasible generator in `GAN/models/gen.py` looks like this:

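```python
# Imports assumed for the Keras 2.x API; the original snippet omits them.
from keras.layers import (Input, Dense, BatchNormalization, Activation,
                          Reshape, UpSampling2D, Conv2D, Deconv2D)
from keras.models import Model


def basic_gen(input_shape, img_shape, nf=128, scale=4, FC=[], use_upsample=False):
    # img_shape is (channels, height, width), i.e. the 'channels_first' layout
    dim, h, w = img_shape

    img = Input(input_shape)
    x = img
    # optional fully connected layers applied to the noise vector
    for fc_dim in FC:
        x = Dense(fc_dim)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # project the noise to a small feature map; // keeps the sizes integral
    x = Dense(nf*2**(scale-1)*(h//2**scale)*(w//2**scale))(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Reshape((nf*2**(scale-1), h//2**scale, w//2**scale))(x)

    for s in range(scale-2, -1, -1):
        # upsampling can eliminate the checkerboard artifact
        # http://distill.pub/2016/deconv-checkerboard/
        if use_upsample:
            x = UpSampling2D()(x)
        else:
            x = Deconv2D(nf*2**s, (3, 3), strides=(2, 2), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # final 2x upsampling to the full image size with `dim` channels
    if use_upsample:
        x = UpSampling2D()(x)
        x = Conv2D(dim, (3, 3), padding='same')(x)
    else:
        x = Deconv2D(dim, (3, 3), strides=(2, 2), padding='same')(x)

    x = Activation('tanh')(x)

    return Model(img, x)
```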
• `scale` is the number of times the image is upsampled by a factor of 2.
• `FC` lists the sizes of the fully connected layers applied before upsampling.
• `use_upsample` replaces `deconv` with `upsample+conv` to get better image quality. Please check this post for details.
• I use `tanh` as the output activation, which means real images should also be normalized to $[-1, 1]$.
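
To make the roles of `nf` and `scale` concrete, the sketch below traces the tensor shapes through `basic_gen` in plain Python, without Keras. It assumes the default deconvolution path (`use_upsample=False`) and that `h` and `w` are divisible by `2**scale`. The helper name `trace_gen_shapes` is hypothetical and not part of the repository.

```python
def trace_gen_shapes(img_shape, nf=128, scale=4):
    """Return the (channels, height, width) shape after each stage of basic_gen.

    Hypothetical helper: it only mirrors the shape arithmetic of the default
    (use_upsample=False) path and builds no Keras layers.
    """
    dim, h, w = img_shape
    if h % 2**scale or w % 2**scale:
        raise ValueError('h and w must be divisible by 2**scale')

    # Shape right after the Dense + Reshape stage.
    hh, ww = h // 2**scale, w // 2**scale
    shapes = [(nf * 2**(scale - 1), hh, ww)]

    # Each stride-2 transposed convolution doubles the height and width.
    for s in range(scale - 2, -1, -1):
        hh, ww = hh * 2, ww * 2
        shapes.append((nf * 2**s, hh, ww))

    # The last transposed convolution maps to the image channels.
    shapes.append((dim, hh * 2, ww * 2))
    return shapes


for shape in trace_gen_shapes((3, 64, 64)):
    print(shape)
```

For a 3x64x64 image with the default arguments, the feature map starts at 1024x4x4 and is doubled in resolution four times (once per unit of `scale`) until it reaches 3x64x64.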

### Discriminator

The discriminator is simpler than the generator. In `GAN/models/dis.py` I implemented a very simple one:

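```python
# Imports assumed for the Keras 2.x API; the original snippet omits them.
from keras.layers import (Input, Dense, BatchNormalization, Conv2D,
                          Flatten, LeakyReLU)
from keras.models import Model


def basic_dis(input_shape, nf=128, scale=4, FC=[], bn=True):
    dim, h, w = input_shape

    img = Input(input_shape)
    x = img

    # each stride-2 convolution halves the resolution and doubles the filters
    for s in range(scale):
        x = Conv2D(nf*2**s, (5, 5), strides=(2, 2), padding='same')(x)
        if bn:
            x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    x = Flatten()(x)
    # optional fully connected layers
    for fc in FC:
        x = Dense(fc)(x)
        if bn:
            x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
    # single sigmoid unit: probability that the input is a real image
    x = Dense(1, activation='sigmoid')(x)

    return Model(img, x)
```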
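
As a rough check of how `basic_dis` shrinks its input, the following plain-Python sketch follows the shapes through the convolution stack and reports the size of the flattened vector. It assumes `scale >= 1` and relies on the fact that a stride-2 convolution with `'same'` padding produces `ceil(size / 2)` outputs per spatial dimension. The helper name `trace_dis_shapes` is hypothetical.

```python
import math


def trace_dis_shapes(input_shape, nf=128, scale=4):
    """Return the per-layer (channels, height, width) shapes and the Flatten size.

    Hypothetical helper mirroring the convolution stack of basic_dis;
    it builds no Keras layers and assumes scale >= 1.
    """
    dim, h, w = input_shape
    shapes = []
    for s in range(scale):
        # stride 2 with 'same' padding: each spatial size becomes ceil(size / 2)
        h, w = math.ceil(h / 2), math.ceil(w / 2)
        shapes.append((nf * 2**s, h, w))

    channels, h, w = shapes[-1]
    return shapes, channels * h * w


shapes, flat_size = trace_dis_shapes((3, 64, 64))
print(shapes)
print('flattened size:', flat_size)
```

With the defaults and a 3x64x64 input, the last feature map is 1024x4x4, which mirrors the first feature map of the generator.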

### GAN

The GAN code is in `GAN/models/GAN.py`.

To train the generator, we first have to connect it to the discriminator:

``gendis = gan = Sequential([generator, discriminator])``

For simplicity, we use `gen` to refer to the generator and `dis` to refer to the discriminator. While `gen` is being trained, the weights of `dis` should stay fixed. We can enforce that by setting the `trainable` flag to `False` before compiling.

```python
dis.trainable = False
gendis.compile(optimizer=opt, loss='binary_crossentropy')
```

To train the discriminator, we should keep the real and fake images in separate batches, as suggested in this blog. If you mix the two types of images in one batch, the discrepancy between their features in the batch normalization layers can become so large that discrimination is trivial and nothing is learned.

In my code, I use a two-branch model to forward the two types of images separately:
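
The toy sketch below illustrates one way to read this claim; it is not a measurement from the actual model. It normalizes made-up scalar "features" the way a batch normalization layer would in training mode (without the learned scale and shift), once for a mixed batch and once for separate batches. The function name `batch_norm` and all numbers are assumptions for illustration.

```python
import statistics


def batch_norm(values, eps=1e-3):
    """Normalize a batch of scalar features with the batch's own statistics."""
    mean = statistics.mean(values)
    var = statistics.pvariance(values, mu=mean)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


# Made-up features: real samples cluster near +1, fake samples near -1.
real = [0.9, 1.0, 1.1, 1.2]
fake = [-1.2, -1.0, -0.9, -1.1]

mixed = batch_norm(real + fake)
separate_real = batch_norm(real)
separate_fake = batch_norm(fake)

print('mixed batch, real part:', [round(v, 2) for v in mixed[:4]])
print('mixed batch, fake part:', [round(v, 2) for v in mixed[4:]])
print('separate batches, real:', [round(v, 2) for v in separate_real])
print('separate batches, fake:', [round(v, 2) for v in separate_fake])
```

In the mixed batch, the sign of the normalized feature alone separates real from fake samples. With separate batches, both groups are centered on zero, so that shortcut disappears.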

```python
shape = dis.get_input_shape_at(0)[1:]
gen_input, real_input = Input(shape), Input(shape)
dis2batch = Model([gen_input, real_input], [dis(gen_input), dis(real_input)])

dis.trainable = True
dis2batch.compile(optimizer=opt, loss='binary_crossentropy', metrics=['binary_accuracy'])
```

To sum up:

```python
gen_trainner = gendis
dis_trainner = dis2batch
```

### Training

The simplest code to train a GAN might look like this:

```python
import numpy as np

for iteration in range(1, niter+1):
    print('iteration', iteration)
    real_img = data_generator(nbatch)
    # sample the noise codes and generate a batch of fake images
    Z = np.random.uniform(-1., 1., size=(nbatch, self.coding)).astype('float32')
    gen_img = self.generate(Z)

    # train the generator: it wants dis to label its images as real (1)
    y = np.ones((nbatch, 1))
    g_loss = self.gen_trainner.train_on_batch(Z, y)

    # train the discriminator: fake images -> 0, real images -> 1
    gen_y   = np.zeros((nbatch, 1))
    real_y  = np.ones((nbatch, 1))
    d_loss = self.dis_trainner.train_on_batch([gen_img, real_img], [gen_y, real_y])
```
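
The label choices in this loop follow from the binary cross-entropy loss. The sketch below computes that loss in plain Python for a single, hypothetical discriminator output, to show why the generator is trained with target 1 while the discriminator uses target 0 for the same fake image. The function is a minimal stand-in written for this illustration, not the Keras implementation.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Binary cross-entropy for one prediction; eps guards against log(0)."""
    y_pred = min(max(y_pred, eps), 1.0 - eps)
    return -(y_true * math.log(y_pred) + (1.0 - y_true) * math.log(1.0 - y_pred))


# Hypothetical discriminator output for a generated image:
# "10% chance that this image is real".
d_out = 0.1

gen_loss = binary_crossentropy(1.0, d_out)  # generator target: real (1)
dis_loss = binary_crossentropy(0.0, d_out)  # discriminator target: fake (0)

print('generator loss:', gen_loss)
print('discriminator loss:', dis_loss)
```

When the discriminator confidently rejects a fake image, its own loss is small (about 0.11 here) while the generator's loss is large (about 2.30), so the generator is pushed to produce images that the discriminator scores closer to 1.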

As many works suggest, the discriminator should be trained more often than the generator. To do that, we can modify the code:

```python
k = 3 # a constant, how many times we train dis before training gen
for iteration in range(1, niter+1):
    ...
    if iteration % (k+1) == 0:
        # train gen (once every k+1 iterations)
        pass
    else:
        # train dis (the other k iterations)
        pass
```
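
A schedule with `k` discriminator updates for every generator update, as the text describes, can be checked without any model. The helper `update_schedule` below is hypothetical and only records which network would be updated at each iteration.

```python
def update_schedule(niter, k=3):
    """List which network is updated at each iteration (1-based).

    Hypothetical helper: k discriminator updates, then one generator update.
    """
    return ['gen' if iteration % (k + 1) == 0 else 'dis'
            for iteration in range(1, niter + 1)]


schedule = update_schedule(8, k=3)
print(schedule)
print(schedule.count('dis'), 'dis updates,', schedule.count('gen'), 'gen updates')
```

With `k = 3` and 8 iterations, the discriminator is updated 6 times and the generator twice.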

Also, to improve the stability of training, we can collect recently generated (fake) images in a pool and sample the discriminator's fake inputs from it:

```python
fake_pool = []
for iteration in range(1, niter+1):
    ...

    # previous version:
    #     gen_img = self.generate(Z)

    # current version: keep the latest `pool_size` generated images
    fake_pool.extend(self.generate(Z))
    fake_pool = fake_pool[-pool_size:]
    # sample nbatch distinct images (requires pool_size >= nbatch)
    idx = np.random.choice(len(fake_pool), size=(nbatch,), replace=False)
    gen_img = np.array(fake_pool)[idx]

    ...
```
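
The pool logic can also be wrapped in a small class. The sketch below is one possible design, not the repository's code: the class name `FakePool` is hypothetical, and strings stand in for image arrays so that it runs with the standard library alone.

```python
import random
from collections import deque


class FakePool:
    """Keep the most recent `pool_size` generated samples (hypothetical helper)."""

    def __init__(self, pool_size):
        # A deque with maxlen drops the oldest items automatically.
        self.items = deque(maxlen=pool_size)

    def add(self, batch):
        """Append a batch of generated samples to the pool."""
        self.items.extend(batch)

    def sample(self, nbatch):
        """Draw up to nbatch distinct samples from the pool."""
        # Fall back to the whole pool while it is smaller than nbatch.
        n = min(nbatch, len(self.items))
        return random.sample(list(self.items), n)


pool = FakePool(pool_size=6)
pool.add(['img0', 'img1', 'img2', 'img3'])
pool.add(['img4', 'img5', 'img6', 'img7'])  # 'img0' and 'img1' are evicted
batch = pool.sample(4)
print(sorted(pool.items))
print(batch)
```

Using a bounded `deque` avoids re-slicing the list on every iteration, and sampling without replacement keeps duplicates out of a batch.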

### Postscript:

Other training tricks can be found here.

Based on this framework, I also implemented Conditional GAN, InfoGAN, and other GAN variants with the Keras 1.x API in `legacy/`. Although that code is not maintained and might not even work with the Keras 2.x API, you can check it for inspiration.

I used Theano before Keras, and it was taxing to build a deep neural network with raw Theano, or even with TensorFlow. I implemented some Keras-like libraries, but they were not as well organized and forward-looking as Keras, so I turned to the Keras community instead.
