This tutorial guides you through implementing a GAN with Keras. The complete code can be accessed in my GitHub repository. If you are not familiar with GANs, please check the first part of this post or another blog to get the gist of them.
Generator
The generator is used to generate images from noise. Following DCGAN, one feasible generator in GAN/models/gen.py looks like:
from keras.models import Model
from keras.layers import (Input, Dense, Reshape, Activation,
                          BatchNormalization, Conv2D, Deconv2D, UpSampling2D)

def basic_gen(input_shape, img_shape, nf=128, scale=4, FC=[], use_upsample=False):
    # assumes image_data_format='channels_first' (Theano-style ordering)
    dim, h, w = img_shape
    img = Input(input_shape)
    x = img
    # optional fully connected layers before the convolutional part
    for fc_dim in FC:
        x = Dense(fc_dim)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    # project and reshape to the smallest feature map
    x = Dense(nf * 2**(scale-1) * (h // 2**scale) * (w // 2**scale))(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Reshape((nf * 2**(scale-1), h // 2**scale, w // 2**scale))(x)
    for s in range(scale-2, -1, -1):
        # upsample + conv can eliminate the checkerboard artifact:
        # http://distill.pub/2016/deconv-checkerboard/
        if use_upsample:
            x = UpSampling2D()(x)
            x = Conv2D(nf * 2**s, (3, 3), padding='same')(x)
        else:
            # Deconv2D is an alias of Conv2DTranspose in Keras 2
            x = Deconv2D(nf * 2**s, (3, 3), strides=(2, 2), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    # final upsampling step down to the target number of channels
    if use_upsample:
        x = UpSampling2D()(x)
        x = Conv2D(dim, (3, 3), padding='same')(x)
    else:
        x = Deconv2D(dim, (3, 3), strides=(2, 2), padding='same')(x)
    x = Activation('tanh')(x)
    return Model(img, x)
- The scale means how many times the image will be scaled up (each step doubles the height and width). FC specifies the fully connected layers before the upsampling part. use_upsample is an option to replace deconv with upsample + conv to get better image quality; please check this post for details.
- I use tanh as the output activation, which also means real images should be normalized to [-1, 1].
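As a quick sanity check, here is a hypothetical instantiation; the 100-dim noise code and 1x32x32 output shape are my assumptions for illustration, not anything fixed by the code, and it assumes image_data_format='channels_first' in keras.json:

# hypothetical shapes: 100-dim noise code, single-channel 32x32 output
gen = basic_gen(input_shape=(100,), img_shape=(1, 32, 32),
                nf=64, scale=4, FC=[1024], use_upsample=True)
gen.summary()  # the last activation is tanh, so outputs lie in [-1, 1]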
Discriminator
The discriminator is simpler than the generator; in GAN/models/dis.py I implemented a very simple one:
from keras.layers import Flatten, LeakyReLU  # plus the layers imported above

def basic_dis(input_shape, nf=128, scale=4, FC=[], bn=True):
    dim, h, w = input_shape  # channels-first, like the generator
    img = Input(input_shape)
    x = img
    # strided convolutions halve the spatial size `scale` times
    for s in range(scale):
        x = Conv2D(nf * 2**s, (5, 5), strides=(2, 2), padding='same')(x)
        if bn: x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
    x = Flatten()(x)
    # optional fully connected layers before the decision
    for fc in FC:
        x = Dense(fc)(x)
        if bn: x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
    x = Dense(1, activation='sigmoid')(x)
    return Model(img, x)
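And a matching discriminator for the same hypothetical 1x32x32 images used in the generator example above:

# same assumed shapes as the generator example
dis = basic_dis(input_shape=(1, 32, 32), nf=64, scale=4, FC=[])
dis.summary()  # a single sigmoid output: the probability of being real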
GAN
The code of the GAN is in GAN/models/GAN.py.
To train the generator, we first have to connect it to the discriminator:
gendis = gan = Sequential([generator, discriminator])
For simplicity, we use gen to refer to the generator and dis to refer to the discriminator. While training gen, the weights of dis must stay fixed. We can enforce that by setting the trainable flag before compiling: Keras records the trainable flags at compile time, so gendis keeps dis frozen even though we re-enable dis.trainable later for the discriminator trainer.
dis.trainable = False
gendis.compile(optimizer=opt, loss='binary_crossentropy')
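The snippets here never define opt; any Keras optimizer works. As an assumption, the Adam settings recommended by the DCGAN paper are a common choice:

from keras.optimizers import Adam

# lr and beta_1 follow the DCGAN paper's recommendation; tune as needed
opt = Adam(lr=0.0002, beta_1=0.5)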
To train the discriminator, we should separate the real and fake images, as suggested in this blog. If the two types of images are mixed in one batch, the batch normalization layers see very different feature statistics for each type, which makes discrimination so easy that nothing useful is learned.
In my code, I use a two-branch model to forward the two types of images separately:
shape = dis.get_input_shape_at(0)[1:]
gen_input, real_input = Input(shape), Input(shape)
dis2batch = Model([gen_input, real_input], [dis(gen_input), dis(real_input)])
dis.trainable = True
dis2batch.compile(optimizer=opt, loss='binary_crossentropy', metrics=['binary_accuracy'])
To sum up:
gen_trainner = gendis
dis_trainner = dis2batch
Training
The simplest code to train a GAN might look like this:
for iteration in range(1, niter+1):
    print('iteration', iteration)
    # a batch of real images, normalized to [-1, 1]
    real_img = data_generator(nbatch)
    # a batch of noise codes and the fake images generated from them
    Z = np.random.uniform(-1., 1., size=(nbatch, self.coding)).astype('float32')
    gen_img = self.generate(Z)

    # train gen: the frozen dis should label the fakes as real (1)
    y = np.ones((nbatch, 1))
    g_loss = self.gen_trainner.train_on_batch(Z, y)

    # train dis: fakes are labeled 0, real images 1
    gen_y = np.zeros((nbatch, 1))
    real_y = np.ones((nbatch, 1))
    d_loss = self.dis_trainner.train_on_batch([gen_img, real_img], [gen_y, real_y])
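The loop assumes a data_generator(nbatch) that yields a batch of real images. A minimal sketch for an in-memory dataset, assuming raw images in [0, 255] that must be normalized to [-1, 1] to match the tanh output (make_data_generator is a hypothetical helper, not part of the repository):

import numpy as np

def make_data_generator(X):
    """X: array of images with values in [0, 255], channels first."""
    X = X.astype('float32') / 127.5 - 1.0   # normalize to [-1, 1] for tanh
    def data_generator(nbatch):
        idx = np.random.choice(len(X), size=nbatch, replace=False)
        return X[idx]
    return data_generator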
As suggested by many works, the discriminator should be trained more often than the generator. To do that, we can modify the loop:
k = 3  # a constant: how many times we train dis before training gen
for iteration in range(1, niter+1):
    ...
    if iteration % (k+1) == 0:
        # train gen (once every k+1 iterations)
    else:
        # train dis (the remaining k out of every k+1 iterations)
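Filling in the branches, a sketch of the full schedule built from the pieces above (note that gen sits in the % (k+1) == 0 branch, so it is updated once for every k discriminator updates):

k = 3
for iteration in range(1, niter+1):
    Z = np.random.uniform(-1., 1., size=(nbatch, self.coding)).astype('float32')
    if iteration % (k+1) == 0:
        # generator update: the frozen dis should call the fakes real
        g_loss = self.gen_trainner.train_on_batch(Z, np.ones((nbatch, 1)))
    else:
        # discriminator update: fakes labeled 0, real images labeled 1
        real_img = data_generator(nbatch)
        gen_img = self.generate(Z)
        d_loss = self.dis_trainner.train_on_batch(
            [gen_img, real_img],
            [np.zeros((nbatch, 1)), np.ones((nbatch, 1))])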
Also, to improve the stability of training, we can collect recently generated (fake) images in a pool and sample the discriminator's fake inputs from it:
fake_pool = []
for iteration in range(1, niter+1):
    ...
    # previous version:
    # gen_img = self.generate(Z)
    # current version:
    fake_pool.extend(self.generate(Z))
    fake_pool = fake_pool[-pool_size:]
    gen_img = np.array(fake_pool)[
        np.random.choice(len(fake_pool), size=(nbatch,), replace=False)]
    ...
Postscript:
Other training tricks can be found here.
Based on this framework, I also implemented Conditional GAN, InfoGAN, and other GAN variants with the Keras 1.x API in legacy/. Although that code is not maintained and might not even work with the Keras 2.x API, you can check it for inspiration.
I used Theano before Keras, and it was taxing to build a deep neural network with raw Theano, or even with TensorFlow. I implemented some Keras-like libraries of my own, but they were not as well organized and forward-looking as Keras, so I turned to the Keras community instead.