This tutorial guides you through implementing a GAN with Keras. The complete code can be accessed in my GitHub repository. If you are not familiar with GANs, please check the first part of this post or another blog to get the gist of them.
The generator produces images from noise. Following DCGAN, one feasible generator in `GAN/models/gen.py` looks like this:
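```python
from keras.layers import (Input, Dense, BatchNormalization, Activation, Reshape,
                          UpSampling2D, Conv2D, Conv2DTranspose)
from keras.models import Model

# Assumes Keras is configured for channels-first images: img_shape = (channels, height, width)
def basic_gen(input_shape, img_shape, nf=128, scale=4, FC=[], use_upsample=False):
    dim, h, w = img_shape
    noise = Input(input_shape)
    x = noise
    # optional fully connected layers on the noise vector
    for fc_dim in FC:
        x = Dense(fc_dim)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    # project and reshape to the smallest feature map (integer division keeps shapes ints)
    x = Dense(nf*2**(scale-1) * (h//2**scale) * (w//2**scale))(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Reshape((nf*2**(scale-1), h//2**scale, w//2**scale))(x)
    # each step doubles height and width while halving the number of filters
    for s in range(scale-2, -1, -1):
        # upsampling + convolution can eliminate the checkerboard artifact
        # http://distill.pub/2016/deconv-checkerboard/
        if use_upsample:
            x = UpSampling2D()(x)
            x = Conv2D(nf*2**s, (3, 3), padding='same')(x)
        else:
            x = Conv2DTranspose(nf*2**s, (3, 3), strides=(2, 2), padding='same')(x)
        x = BatchNormalization(axis=1)(x)  # normalize over the channel axis
        x = Activation('relu')(x)
    # last upscaling step maps to `dim` image channels
    if use_upsample:
        x = UpSampling2D()(x)
        x = Conv2D(dim, (3, 3), padding='same')(x)
    else:
        x = Conv2DTranspose(dim, (3, 3), strides=(2, 2), padding='same')(x)
    x = Activation('tanh')(x)  # outputs lie in [-1, 1]
    return Model(noise, x)
```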
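- `scale` is how many times the image will be scaled up (each step doubles its height and width).
- `FC` lists the sizes of the fully connected layers applied before upsampling.
- `use_upsample` replaces the deconvolution layers with upsampling followed by a convolution, which avoids checkerboard artifacts and can give better image quality. Please check this post for details.
- I use `tanh` as the output activation, which means real images should be normalized to [-1, 1].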
The discriminator is simpler than the generator. In `GAN/models/dis.py`, I implemented a very simple one:
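```python
from keras.layers import Input, Conv2D, BatchNormalization, LeakyReLU, Flatten, Dense
from keras.models import Model

# Assumes channels-first input: input_shape = (channels, height, width)
def basic_dis(input_shape, nf=128, scale=4, FC=[], bn=True):
    dim, h, w = input_shape
    img = Input(input_shape)
    x = img
    # each stride-2 convolution halves height and width and doubles the filters
    for s in range(scale):
        x = Conv2D(nf*2**s, (5, 5), strides=(2, 2), padding='same')(x)
        if bn:
            x = BatchNormalization(axis=1)(x)  # normalize over the channel axis
        x = LeakyReLU(0.2)(x)
    x = Flatten()(x)
    # optional fully connected layers
    for fc in FC:
        x = Dense(fc)(x)
        if bn:
            x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
    # probability that the input image is real
    x = Dense(1, activation='sigmoid')(x)
    return Model(img, x)
```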
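The two networks mirror each other spatially: the generator starts from an h/2**scale by w/2**scale feature map and doubles it `scale` times, while the discriminator halves its input `scale` times with stride-2 convolutions. The pure-Python sketch below traces those side lengths; the function names are hypothetical, and it assumes that a stride-2 convolution with `padding='same'` produces ceil(n / 2) outputs, which is how "same" padding is defined.

```python
def gen_sizes(h, scale):
    """Side lengths in basic_gen: start at h // 2**scale, then double `scale` times."""
    size = h // 2**scale
    sizes = [size]
    for _ in range(scale):
        size *= 2
        sizes.append(size)
    return sizes

def dis_sizes(h, scale):
    """Side lengths in basic_dis: each stride-2 'same' conv gives ceil(n / 2)."""
    sizes = [h]
    for _ in range(scale):
        h = (h + 1) // 2
        sizes.append(h)
    return sizes

def dis_flat_features(h, w, nf=128, scale=4):
    """Features after Flatten() in basic_dis; the last conv has nf * 2**(scale-1) filters."""
    return nf * 2**(scale - 1) * dis_sizes(h, scale)[-1] * dis_sizes(w, scale)[-1]
```

For 64x64 images with `scale=4`, the generator grows 4 → 8 → 16 → 32 → 64 and the discriminator shrinks 64 → 32 → 16 → 8 → 4. For 28x28 images, `scale=2` works (7 → 14 → 28), but `scale=4` would make the generator emit 16x16 images, so `scale` should be chosen so that 2**scale divides the image size.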
The code for the GAN itself is in
To train the generator, we first have to connect it to the discriminator. For simplicity, we use `gen` to refer to the generator and `dis` to refer to the discriminator:

```python
from keras.models import Sequential

# noise -> gen -> image -> dis -> probability that the image is real
gendis = Sequential([gen, dis])
```

While `gen` is being trained, the weights of `dis` should stay fixed. We can enforce that by setting the `trainable` flag before compiling:
```python
dis.trainable = False  # freeze dis inside gendis
gendis.compile(optimizer=opt, loss='binary_crossentropy')  # opt: any Keras optimizer
```
To train the discriminator, we should feed real and fake images in separate batches, as suggested in this blog. If you mix the two types of images in one batch, the feature statistics in the batch normalization layers can differ so much between them that discrimination becomes trivial and nothing useful is learned.
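The batch normalization argument can be illustrated with plain NumPy. This is a simplified model, not the author's code: it looks at a single feature whose mean differs between real and fake images (the means 1.0 and -1.0 are made-up values for illustration) and standardizes it the way a batch normalization layer does in training mode, with the learned scale and shift omitted.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Standardize a batch of scalar features with the batch's own mean and variance."""
    return (x - x.mean()) / np.sqrt(x.var() + eps)

rng = np.random.RandomState(0)
real = rng.normal(loc=1.0, scale=0.1, size=64)   # feature values for real images
fake = rng.normal(loc=-1.0, scale=0.1, size=64)  # feature values for fake images

# Mixed batch: after normalization, the sign alone separates real from fake.
mixed = batch_norm(np.concatenate([real, fake]))
mixed_real, mixed_fake = mixed[:64], mixed[64:]

# Separate batches: each is centered at zero, so the mean shift is removed.
sep_real, sep_fake = batch_norm(real), batch_norm(fake)
```

In the mixed batch the normalized feature is positive for every real image and negative for every fake one, which is exactly the kind of shortcut that lets the discriminator win without learning anything about image content.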
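In my code, I use a two-branch model to forward the two types of images separately:

```python
from keras.layers import Input
from keras.models import Model

# input shape of dis without the batch dimension
shape = dis.get_input_shape_at(0)[1:]
gen_input, real_input = Input(shape), Input(shape)
# dis is applied to each input separately, so fake and real images
# go through batch normalization as two separate batches
dis2batch = Model([gen_input, real_input], [dis(gen_input), dis(real_input)])
dis.trainable = True
dis2batch.compile(optimizer=opt, loss='binary_crossentropy', metrics=['binary_accuracy'])
```

As the Keras FAQ on freezing layers explains, the `trainable` flag takes effect when a model is compiled, so setting `dis.trainable = True` here does not unfreeze `dis` inside `gendis`, which was compiled while `dis` was frozen.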
To sum up:
```python
gen_trainner = gendis
dis_trainner = dis2batch
```
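The simplest code to train a GAN might look like this:

```python
import numpy as np

# Excerpt from a training method: niter, nbatch, data_generator, self.coding
# (the noise dimension) and self.generate are defined elsewhere.
for iteration in range(1, niter+1):
    print('iteration', iteration)
    real_img = data_generator(nbatch)
    Z = np.random.uniform(-1., 1., size=(nbatch, self.coding)).astype('float32')
    gen_img = self.generate(Z)

    # train gen: label its samples as real (1) so it learns to fool dis
    y = np.ones((nbatch, 1))
    g_loss = self.gen_trainner.train_on_batch(Z, y)

    # train dis: fake images -> 0, real images -> 1
    gen_y = np.zeros((nbatch, 1))
    real_y = np.ones((nbatch, 1))
    d_loss = self.dis_trainner.train_on_batch([gen_img, real_img], [gen_y, real_y])
```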
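The loop relies on a `data_generator(nbatch)` helper that the snippet does not show. One way such a helper could look, assuming the whole training set fits in memory as a NumPy array already normalized to [-1, 1]; `make_data_generator` and its arguments are hypothetical:

```python
import numpy as np

def make_data_generator(images, seed=None):
    """Return a function that samples a random batch of real images.

    images: array of shape (n_samples, channels, height, width),
    already scaled to [-1, 1] to match the generator's tanh output.
    """
    rng = np.random.RandomState(seed)

    def data_generator(nbatch):
        # draw nbatch distinct indices, then gather those images
        idx = rng.choice(len(images), size=nbatch, replace=False)
        return images[idx]

    return data_generator

images = np.linspace(-1, 1, 10 * 1 * 4 * 4, dtype='float32').reshape(10, 1, 4, 4)
data_generator = make_data_generator(images, seed=0)
batch = data_generator(4)
```

Sampling indices per call keeps the helper stateless apart from its random generator; an epoch-based shuffle would be an equally valid design.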
Many works suggest training the discriminator more often than the generator. To do that, we can modify the code so that `dis` is trained on k out of every k+1 iterations:

```python
k = 3  # a constant: how many times we train dis before training gen
for iteration in range(1, niter+1):
    ...
    if iteration % (k+1) == 0:
        ...  # train gen (once every k+1 iterations)
    else:
        ...  # train dis (the other k iterations)
```
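Since `k` counts how many times the discriminator is trained before the generator, the schedule should pick `dis` on k iterations and `gen` on one iteration out of every k+1. A small pure-Python sketch of that schedule; the helper name `which_to_train` is hypothetical:

```python
def which_to_train(iteration, k=3):
    """Return 'gen' on every (k+1)-th iteration and 'dis' otherwise."""
    return 'gen' if iteration % (k + 1) == 0 else 'dis'

schedule = [which_to_train(i) for i in range(1, 9)]
```

With `k=3` and iterations starting at 1, this yields three discriminator updates followed by one generator update, repeated.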
Also, to improve the stability of training, we can collect recent fake/generated images in a pool and sample the network inputs from it:
```python
fake_pool = []  # pool_size: maximum number of generated images to keep
for iteration in range(1, niter+1):
    ...
    # previous version:
    # gen_img = self.generate(Z)
    # current version:
    fake_pool.extend(self.generate(Z))
    fake_pool = fake_pool[-pool_size:]  # keep only the most recent images
    gen_img = np.array(fake_pool)[np.random.choice(len(fake_pool), size=(nbatch,), replace=False)]
    ...
```
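The same pool logic can be packaged as a small class. This is a sketch rather than the repository's code: `FakeImagePool` is a hypothetical name, and it swaps the list slicing for `collections.deque(maxlen=...)`, which evicts the oldest images automatically.

```python
from collections import deque

import numpy as np

class FakeImagePool:
    """Keep the most recent `pool_size` generated images and sample batches from them."""

    def __init__(self, pool_size, seed=None):
        self.images = deque(maxlen=pool_size)
        self.rng = np.random.RandomState(seed)

    def add(self, batch):
        # extend() iterates over the first axis, adding one image at a time
        self.images.extend(batch)

    def sample(self, nbatch):
        # sample without replacement; requires at least nbatch images in the pool
        idx = self.rng.choice(len(self.images), size=nbatch, replace=False)
        return np.array(self.images)[idx]

pool = FakeImagePool(pool_size=8, seed=0)
for step in range(3):
    pool.add(np.full((4, 1, 2, 2), step, dtype='float32'))
gen_img = pool.sample(4)
```

Inside the training loop, `pool.add(self.generate(Z))` followed by `gen_img = pool.sample(nbatch)` would play the role of the three pool lines.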
Other training tricks can be found here.
Based on this framework, I also implemented Conditional GAN, InfoGAN, and other GAN variants with the Keras 1.x API in `legacy/`. Although that code is not maintained and might not even work with the Keras 2.x API, you can check it for inspiration.
I used Theano before Keras, and building deep neural networks with raw Theano, or even TensorFlow, was taxing. I implemented some Keras-like libraries, but they were not as well organized or forward-looking as Keras, so I turned to the Keras community instead.