Implement improved WGAN with Keras-2.x

Posted on May 28, 2017

This tutorial is based on Improved Training of Wasserstein GANs (IWGAN). You will learn how to write customized layers and how to build IWGAN with Keras. The code is available in my GitHub repository.

Notice: Keras is updated so quickly that some of the layers below (e.g. the subtraction layer) can already be found in the official library. If you are new to GAN and Keras, please implement a basic GAN first. After that, check the GradNorm layer in this post, which is the most essential part of IWGAN.

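The classical GAN uses the following objective, which can be interpreted as “minimizing the JS divergence between the fake and real distributions”:

\min_G \max_D \mathbb{E}_{x\sim P_r}[\log D(x)] + \mathbb{E}_{\widetilde{x}\sim P_g}[\log(1-D(\widetilde{x}))]

Here P_r is the real data distribution and P_g is the distribution of generated samples.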

The WGAN authors argue that the JS divergence cannot provide enough information when the discrepancy between the two distributions is too large. In contrast, the Wasserstein distance stays informative even when the two distributions do not overlap. However, computing the Wasserstein distance directly is intractable, so they optimize its dual form instead:

\min_G \max_{D\in\mathcal{D}} \mathbb{E}_{x\sim P_r}[D(x)] - \mathbb{E}_{\widetilde{x}\sim P_g}[D(\widetilde{x})]

where \mathcal{D} is the set of 1-Lipschitz functions, x is drawn from the real distribution P_r and \widetilde{x} from the generated distribution P_g. To keep D\in \mathcal{D} during training, WGAN clips the weights of the discriminator to a small range.
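As a rough illustration of that clipping step, here is a minimal NumPy sketch. The helper name `clip_weights` and the list of arrays standing in for the critic's parameters are hypothetical, and the threshold `c=0.01` is an assumption taken from the default in the WGAN paper.

```python
import numpy as np

def clip_weights(weights, c=0.01):
    """Return a copy of each weight array clipped element-wise to [-c, c]."""
    return [np.clip(w, -c, c) for w in weights]

# Hypothetical critic parameters: one kernel and one bias.
rng = np.random.RandomState(0)
weights = [rng.normal(size=(4, 3)), rng.normal(size=(3,))]

clipped = clip_weights(weights)
max_abs = max(np.abs(w).max() for w in clipped)
print(max_abs)
```

With a Keras model, the same idea could be applied after each critic update by passing the result of `get_weights()` through such a helper and writing it back with `set_weights()`.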


However, according to the authors of IWGAN, weight clipping can result in undesirable behavior.

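Based on the observation that the optimal critic has gradient norm 1 almost everywhere under the real and generated distributions, they came up with an improved objective:

L = \mathbb{E}_{\widetilde{x}\sim P_g}[D(\widetilde{x})] - \mathbb{E}_{x\sim P_r}[D(x)] + \lambda \mathbb{E}_{\hat{x}\sim P_{\hat{x}}}[(\|\bigtriangledown_{\hat{x}}D(\hat{x})\|_2-1)^2]

with x a real sample (P_r), \widetilde{x} a generated sample (P_g), and \hat{x} sampled uniformly along straight lines between pairs of real and generated samples.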

When the gradient penalty is fully optimized, the discriminator should act like a 1-Lipschitz function.

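The pseudo-code is also given in their paper (Algorithm 1). The line numbers used below refer to it: line 7 computes the critic loss D_w(\widetilde{x})-D_w(x)+\lambda (\|\bigtriangledown_{\hat{x}}D_w(\hat{x})\|_2-1)^2, and line 12 updates the generator by minimizing -D_w(G(z)).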



Compared with a basic GAN, IWGAN needs some components that Keras does not provide, so we have to customize them:

  • Use mean of output as loss.
  • Merging two variables through subtraction.
  • Use gradient as loss.


Use mean of output as loss (Used in line 7, line 12)

Keras provides various losses, but none of them uses the model output directly as the loss value. Therefore, we have to customize the loss function:

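from keras import backend as K

def multiple_loss(y_true, y_pred):
    # y_true carries a constant multiplier (e.g. -1), not a label
    return K.mean(y_true*y_pred)

def mean_loss(y_true, y_pred):
    # y_true is ignored; the model output itself is the loss
    return K.mean(y_pred)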

The second loss ignores y_true, but sometimes we want to multiply the output by a constant (like the minus one in line 12 of the pseudo-code), which is what the first loss is for. Note that K.mean calculates the mean over all values. If your output has multiple dimensions, you might want to sum over the non-batch axes first and then take the mean over the batch. Here the output of the discriminator/critic has only one value per sample, so this is not an issue.
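To make the effect of the two losses concrete, here is a small NumPy sketch. The functions `multiple_loss_np` and `mean_loss_np` are hypothetical stand-ins for the Keras versions above, and the critic outputs are made-up numbers.

```python
import numpy as np

def multiple_loss_np(y_true, y_pred):
    # NumPy stand-in for multiple_loss: y_true acts as a per-sample multiplier
    return np.mean(y_true * y_pred)

def mean_loss_np(y_true, y_pred):
    # NumPy stand-in for mean_loss: y_true is ignored
    return np.mean(y_pred)

critic_out = np.array([[0.5], [1.5], [-1.0], [3.0]])  # made-up D(G(z)) values
minus_ones = -np.ones_like(critic_out)

gen_loss = multiple_loss_np(minus_ones, critic_out)   # mean of -D(G(z))
plain_mean = mean_loss_np(None, critic_out)           # mean of D(G(z))
print(gen_loss, plain_mean)
```

Passing a vector of minus ones as y_true therefore turns "the mean of the output" into "minus the mean of the output", which is the quantity the generator update minimizes.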


Merging two variables through subtraction (Used in line 7)

We have to calculate D_w(\widetilde{x})-D_w(x) in line 7 and then apply multiple_loss or mean_loss to use the result as a loss. At the time of writing, Keras did not provide a merge layer for subtraction. Looking into the Keras source code, however, we can see that a merge layer is easy to define: we only have to override the _merge_function method of the _Merge class:

# this is the source code of merge.Add
#class Add(_Merge):
#    def _merge_function(self, inputs):
#        output = inputs[0]
#        for i in range(1, len(inputs)):
#            output = output+inputs[i]
#        return output

from keras.layers.merge import _Merge

class Subtract(_Merge):
    def _merge_function(self, inputs):
        output = inputs[0]
        for i in range(1, len(inputs)):
            output = output-inputs[i]
        return output

The code might be a little confusing. The variable inputs is a list of tensors, and it can hold more than two. We take the first tensor as the output and then subtract every remaining tensor from it (or add them to it, in the case of Add).
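The behavior with more than two inputs can be checked with a NumPy sketch of the same loop. The function `subtract_merge` is a hypothetical stand-in for `Subtract._merge_function`, and the arrays are made-up values.

```python
import numpy as np

def subtract_merge(inputs):
    # Same loop as Subtract._merge_function, applied to NumPy arrays
    output = inputs[0]
    for i in range(1, len(inputs)):
        output = output - inputs[i]
    return output

a = np.array([[5.0], [7.0]])
b = np.array([[1.0], [2.0]])
c = np.array([[0.5], [0.5]])

two = subtract_merge([a, b])        # a - b
three = subtract_merge([a, b, c])   # a - b - c
print(two.ravel(), three.ravel())
```

In this post the layer is only ever called with two inputs, so the result is simply the first tensor minus the second.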


Use gradient as loss

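To calculate D_w(\widetilde{x})-D_w(x)+\lambda (\|\bigtriangledown_{\hat{x}}D_w(\hat{x})\|_2-1)^2, we first calculate D_w(\widetilde{x})-D_w(x) and \|\bigtriangledown_{\hat{x}}D_w(\hat{x})\|_2 separately as two model outputs. We then apply mean_loss to the first and MSE (Mean Squared Error) to the second. With a target of all ones, MSE yields exactly the mean of (\|\bigtriangledown_{\hat{x}}D_w(\hat{x})\|_2-1)^2. Finally, the loss_weights argument of Keras’ compile function weights the second term by \lambda.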
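The following NumPy sketch checks this decomposition numerically. All values are made up, and `lmbd = 10.0` is an assumption (it is the default penalty weight in the paper).

```python
import numpy as np

lmbd = 10.0  # assumed penalty weight
diff = np.array([[0.2], [-0.4], [0.6], [0.0]])   # made-up D(x~) - D(x) per sample
norm = np.array([[1.0], [1.5], [0.5], [2.0]])    # made-up gradient norms per sample
ones = np.ones_like(norm)                        # MSE target

mean_term = np.mean(diff)                    # what mean_loss returns
mse_term = np.mean((norm - ones) ** 2)       # what 'mse' returns with target ones
total = 1.0 * mean_term + lmbd * mse_term    # loss_weights=[1.0, lmbd]

# The same quantity written as a single per-sample expression
direct = np.mean(diff + lmbd * (norm - 1.0) ** 2)
print(total, direct)
```

Both expressions agree, so splitting the objective into two outputs with separate losses does not change the value being minimized.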

What remains is how to calculate \|\bigtriangledown_{\hat{x}}D_w(\hat{x})\|_2. Keras does not provide this operation as a layer, so we have to write a custom layer for it. The code looks like:

from keras import backend as K
from keras.engine.topology import Layer
import numpy as np
class GradNorm(Layer):
    def __init__(self, **kwargs):
        super(GradNorm, self).__init__(**kwargs)

    def build(self, input_shapes):
        super(GradNorm, self).build(input_shapes)

    def call(self, inputs):
        target, wrt = inputs
        grads = K.gradients(target, wrt)
        assert len(grads) == 1
        grad = grads[0]
        return K.sqrt(K.sum(K.batch_flatten(K.square(grad)), axis=1, keepdims=True))

    def compute_output_shape(self, input_shapes):
        return (input_shapes[1][0], 1)

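The layer takes two tensors as input: the target tensor (the critic output) and the wrt tensor (the input with respect to which the gradient is taken). It returns one gradient norm per sample, with shape (batch, 1).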
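To see what this layer computes, here is a NumPy sketch that uses finite differences instead of K.gradients. The toy linear critic, the helper names `critic` and `numeric_grad`, and the tensor shapes are all assumptions for illustration. For a linear critic D(x) = sum(w * x), the gradient with respect to x is w, so every sample should end up with a norm equal to the norm of w.

```python
import numpy as np

def critic(x, w):
    # Toy linear critic: one scalar per sample, D(x) = sum(w * x)
    return np.sum(x * w, axis=(1, 2, 3))

def numeric_grad(x, w, eps=1e-5):
    # Central finite differences of D with respect to every input element
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        step = np.zeros_like(x)
        step[idx] = eps
        n = idx[0]  # index of the sample this element belongs to
        grad[idx] = (critic(x + step, w)[n] - critic(x - step, w)[n]) / (2 * eps)
    return grad

rng = np.random.RandomState(0)
x = rng.normal(size=(3, 2, 2, 1))   # batch of 3 tiny "images"
w = rng.normal(size=(2, 2, 1))      # critic weights

grad = numeric_grad(x, w)
# Same reduction as GradNorm.call: square, flatten per sample, sum, sqrt
grad_norm = np.sqrt(np.sum(np.square(grad).reshape(len(x), -1), axis=1, keepdims=True))
print(grad_norm.shape)
```

The result has one row per sample, matching the (batch, 1) shape declared in compute_output_shape.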


With those components, we can define our model as:

        gen, dis = self.generator, self.discriminator
        gendis = Sequential([gen, dis])

        dis.trainable = False 
        gendis.compile(optimizer=opt, loss=multiple_loss) # output: D(G(Z)) ===(y_true:-1*ones)===>  Loss:(-1) * D(G(Z)) 

        shape = dis.get_input_shape_at(0)[1:]
        gen_input, real_input, interpolation = Input(shape), Input(shape), Input(shape)
        sub = Subtract()([dis(gen_input), dis(real_input)])
        norm = GradNorm()([dis(interpolation), interpolation])
        dis2batch = Model([gen_input, real_input, interpolation], [sub, norm])
        # outputs: D(G(Z))-D(X), norm ===(y_true: ignored, ones)===> Loss: D(G(Z))-D(X)+lmbd*(norm-1)**2
        dis.trainable = True
        dis2batch.compile(optimizer=opt, loss=[mean_loss,'mse'], loss_weights=[1.0, lmbd])


Thanks to @bpgw for the comment. To use the model, please check my code for more details. The key lines are:

# generator
y = np.ones((nbatch, 1)) * (-1)
g_loss = self.gen_trainner.train_on_batch(Z, y) # output: D(G(Z)) ===(-1*ones)===>  Loss:(-1) * D(G(Z))

# discriminator
epsilon = np.random.uniform(0, 1, size=(nbatch,1,1,1))
interpolation = epsilon*real_img + (1-epsilon)*gen_img
d_loss, d_diff, d_norm = self.dis_trainner.train_on_batch([gen_img, real_img, interpolation], [np.ones((nbatch, 1))]*2)
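The shape of epsilon matters in the interpolation step: one random number is drawn per sample and then broadcast over the remaining axes. A minimal NumPy sketch, assuming a hypothetical image shape of 8x8x3 and random stand-ins for the real and generated batches:

```python
import numpy as np

nbatch = 4
rng = np.random.RandomState(0)
real_img = rng.uniform(-1, 1, size=(nbatch, 8, 8, 3))  # hypothetical image shape
gen_img = rng.uniform(-1, 1, size=(nbatch, 8, 8, 3))

# One epsilon per sample; the trailing 1-axes broadcast over height, width, channels
epsilon = rng.uniform(0, 1, size=(nbatch, 1, 1, 1))
interpolation = epsilon * real_img + (1 - epsilon) * gen_img

# Every interpolated pixel lies between the two source pixels
lo = np.minimum(real_img, gen_img)
hi = np.maximum(real_img, gen_img)
inside = bool(np.all((interpolation >= lo - 1e-12) & (interpolation <= hi + 1e-12)))
print(interpolation.shape, inside)
```

The (nbatch,1,1,1) shape assumes 4-D image tensors; for flat feature vectors it would be (nbatch,1) instead. Of the two all-ones targets passed to train_on_batch, the first is ignored by mean_loss and the second is the target gradient norm for the MSE term.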