Notice: Keras is updated quickly, and some of the layers below (e.g. the subtraction layer, now available as `keras.layers.Subtract`) are already in the official library. If you are new to GANs and Keras, please implement a basic GAN first. After that, check the GradNorm layer in this post, which is the most essential part of IWGAN.
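The classical GAN uses the following objective, which can be interpreted as “minimizing the JS divergence between the fake and real distributions”:

$$\min_G \max_D \; \mathbb{E}_{x\sim\mathbb{P}_r}[\log D(x)] + \mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}[\log(1 - D(\tilde{x}))]$$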
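In WGAN, the authors argue that the JS divergence cannot provide enough information when the discrepancy between the two distributions is too large. In contrast, the Wasserstein distance still gives a meaningful measure even when the two distributions do not overlap. However, computing the Wasserstein distance directly is intractable, so they optimize its dual form instead:

$$\min_G \max_{D\in\mathcal{D}} \; \mathbb{E}_{x\sim\mathbb{P}_r}[D(x)] - \mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}[D(\tilde{x})]$$

where $\mathcal{D}$ is the set of 1-Lipschitz functions. To keep $D \in \mathcal{D}$ during training, WGAN clips the weights of the discriminator (critic) to a small range $[-c, c]$.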
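The claim that the JS divergence gives no useful signal for non-overlapping distributions can be checked on the standard toy example of two point masses, one at 0 and one at θ. The plain-Python sketch below is only an illustration; the helper names `js_divergence` and `wasserstein_point_masses` are hypothetical, and distributions are given as dictionaries mapping support points to probabilities.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}

    def kl(a, b):
        # Points where a(x) == 0 contribute nothing to KL(a || b).
        return sum(a[x] * math.log(a[x] / b[x]) for x in a if a[x] > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_point_masses(a, b):
    """W1 distance between point masses at a and b: all mass moves |a - b|."""
    return abs(a - b)

for theta in [0.5, 1.0, 5.0]:
    p = {0.0: 1.0}    # "real" distribution: point mass at 0
    q = {theta: 1.0}  # "fake" distribution: point mass at theta
    print(theta, round(js_divergence(p, q), 4), wasserstein_point_masses(0.0, theta))
```

For every θ ≠ 0 the JS divergence stays at log 2, so it carries no information about how far apart the distributions are, while the Wasserstein distance grows linearly with θ.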
However, according to the authors of IWGAN, weight clipping can lead to undesirable behavior, such as a critic that underuses its capacity and exploding or vanishing gradients.
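Their fix is based on the following observation: the optimal critic of the dual problem has gradient norm 1 almost everywhere under $\mathbb{P}_r$ and $\mathbb{P}_g$, in particular along straight lines between paired real and generated samples.
Based on this, they proposed an improved objective that replaces weight clipping with a gradient penalty:

$$L = \mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x\sim\mathbb{P}_r}[D(x)] + \lambda\,\mathbb{E}_{\hat{x}\sim\mathbb{P}_{\hat{x}}}\big[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\big]$$

where $\hat{x} = \epsilon x + (1-\epsilon)\tilde{x}$ with $\epsilon \sim U[0,1]$, i.e. $\hat{x}$ is sampled uniformly along straight lines between pairs of real and generated samples.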
When the gradient penalty is minimized, the critic's gradient norm is close to 1 along the sampled interpolation lines, so the discriminator behaves approximately like a 1-Lipschitz function.
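To see what each term of the improved objective measures, here is a minimal NumPy sketch with a hypothetical linear critic $D(x) = w \cdot x + b$. A linear critic has the same input gradient $w$ everywhere, so the penalty reduces to $\lambda(\|w\|_2 - 1)^2$ and vanishes exactly when $\|w\|_2 = 1$. The function name `wgan_gp_critic_loss` and the toy data are assumptions for illustration, not code from the paper.

```python
import numpy as np

def wgan_gp_critic_loss(w, b, real, fake, lam=10.0, rng=None):
    """Batch estimate of E[D(x_fake)] - E[D(x_real)] + lam * E[(||grad D(x_hat)||_2 - 1)^2]
    for a hypothetical linear critic D(x) = x @ w + b."""
    rng = np.random.default_rng(0) if rng is None else rng

    def critic(x):
        return x @ w + b

    # x_hat = eps * x_real + (1 - eps) * x_fake, with one eps ~ U[0, 1] per sample
    eps = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))
    x_hat = eps * real + (1.0 - eps) * fake
    # A linear critic has input gradient w everywhere, so the gradient at
    # every interpolated point x_hat is just w.
    grads = np.tile(w, (x_hat.shape[0], 1))
    grad_norm = np.sqrt(np.sum(grads ** 2, axis=1))
    penalty = np.mean((grad_norm - 1.0) ** 2)
    return np.mean(critic(fake)) - np.mean(critic(real)) + lam * penalty

rng = np.random.default_rng(42)
real = rng.normal(loc=2.0, size=(64, 2))
fake = rng.normal(loc=0.0, size=(64, 2))
w_unit = np.array([0.6, 0.8])  # ||w|| = 1, so the penalty is (numerically) zero
w_big = np.array([3.0, 4.0])   # ||w|| = 5, so the penalty is lam * (5 - 1)^2 = 160
print(wgan_gp_critic_loss(w_unit, 0.0, real, fake))
print(wgan_gp_critic_loss(w_big, 0.0, real, fake))
```

Scaling up $w$ makes the first two terms more negative, but the penalty grows quadratically, which is what stops the critic from becoming arbitrarily steep.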
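The pseudo-code is also given in their paper (Algorithm 1). The two lines referenced in the rest of this post are line 7, which computes the per-sample critic loss $L^{(i)} = D(\tilde{x}) - D(x) + \lambda(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2$, and line 12, which updates the generator by minimizing the batch mean of $-D(G(z))$.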
Compared with a basic GAN, IWGAN needs a few components that Keras does not provide, so we have to write them ourselves:
- Use mean of output as loss.
- Merging two variables through subtraction.
- Use gradient as loss.
Use mean of output as loss (Used in line 7, line 12)
Keras provides various loss functions, but none of them uses the model output itself as the loss. Therefore, we have to define our own:
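```python
from keras import backend as K

def multiple_loss(y_true, y_pred):
    # Multiply the output by y_true (e.g. -1) and average over the batch.
    return K.mean(y_true * y_pred)

def mean_loss(y_true, y_pred):
    # Ignore y_true and use the mean of the output as the loss.
    return K.mean(y_pred)
```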
The second loss ignores `y_true`. The first one is for cases where we want to multiply the output by a constant, such as the minus one in line 12 of the pseudo-code, which we pass in as `y_true`. Note that `K.mean` averages over all values; if your critic produces more than one value per sample, you might want to sum over the non-batch axes first and then average over the batch. Here the critic outputs a single value per sample, so `K.mean` is enough.
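The two losses differ only in whether `y_true` is used. The NumPy mirror below (the `_np` names are hypothetical, chosen for illustration; Keras calls the real functions with tensors) shows how passing `y_true = -1` turns `multiple_loss` into the generator loss $-\mathrm{mean}(D(G(z)))$, and how a multi-valued critic output would be reduced per sample first.

```python
import numpy as np

def multiple_loss_np(y_true, y_pred):
    # Same arithmetic as multiple_loss: mean of y_true * y_pred.
    return np.mean(y_true * y_pred)

def mean_loss_np(y_true, y_pred):
    # Same arithmetic as mean_loss: y_true is ignored.
    return np.mean(y_pred)

d_fake = np.array([[0.5], [1.5], [-1.0]])  # critic outputs D(G(z)), shape (batch, 1)
y_minus_one = -np.ones_like(d_fake)

print(multiple_loss_np(y_minus_one, d_fake))       # -mean(D(G(z)))
print(mean_loss_np(np.ones_like(d_fake), d_fake))  # mean(D(G(z))); targets ignored

# For a critic with several outputs per sample, sum over the non-batch axes first.
patch_out = np.ones((4, 3, 3))
per_sample = patch_out.reshape(patch_out.shape[0], -1).sum(axis=1)  # shape (4,)
print(per_sample.mean())  # 9.0
```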
Merging two variables through subtraction (Used in line 7)
We have to calculate $D(\tilde{x}) - D(x)$ in line 7 and then use `multiple_loss` or `mean_loss` to turn that output into a loss. Keras does not provide a merge layer for subtraction, but the Keras source code shows that merge layers are easy to define: we only have to override the `_merge_function` method of `_Merge`:
```python
# For reference, this is the source code of Keras's own Add layer:
# class Add(_Merge):
#     def _merge_function(self, inputs):
#         output = inputs[0]
#         for i in range(1, len(inputs)):
#             output = output + inputs[i]
#         return output

from keras.layers.merge import _Merge

class Subtract(_Merge):
    """Subtract every following input tensor from the first one."""
    def _merge_function(self, inputs):
        output = inputs[0]
        for i in range(1, len(inputs)):
            output = output - inputs[i]
        return output
```
The code might be a little confusing. The variable `inputs` is a list of tensors (it can hold more than two). We take the first tensor as the output and then subtract each of the remaining tensors from it (or add them to it, in the case of `Add`).
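The loop in `_merge_function` is just a left fold over the input list. The NumPy sketch below (the helper `subtract_merge` is hypothetical) runs the same loop on arrays so the result is easy to check; with exactly two inputs it yields $D(\tilde{x}) - D(x)$ per sample.

```python
import numpy as np

def subtract_merge(inputs):
    """Same loop as Subtract._merge_function: first input minus all the others."""
    output = inputs[0]
    for i in range(1, len(inputs)):
        output = output - inputs[i]
    return output

d_fake = np.array([[0.2], [0.7]])  # D(G(z)) for a batch of two
d_real = np.array([[1.0], [0.5]])  # D(x) for a batch of two
print(subtract_merge([d_fake, d_real]))  # D(G(z)) - D(x), shape (2, 1)
print(subtract_merge([np.array([10.0]), np.array([3.0]), np.array([2.0])]))  # 10 - 3 - 2
```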
Use gradient as loss
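To calculate the critic loss $D(\tilde{x}) - D(x) + \lambda(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2$ from line 7, we compute $D(\tilde{x}) - D(x)$ and $\|\nabla_{\hat{x}} D(\hat{x})\|_2$ as two separate outputs, apply `mean_loss` to the first and MSE (mean squared error, with a target of ones) to the second, and combine them with `loss_weights=[1.0, lmbd]` in Keras' `compile` function.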
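Keras sums the per-output losses, each multiplied by its entry in `loss_weights`. A NumPy sketch of that arithmetic, with made-up batch values and the hypothetical name `critic_total_loss`, shows that `mean_loss` on $D(\tilde{x}) - D(x)$ plus $\lambda$ times the MSE of the gradient norm against ones equals the batch average of the line 7 loss.

```python
import numpy as np

def critic_total_loss(sub, norm, lam):
    """1.0 * mean_loss(sub) + lam * mse(ones, norm), as combined via loss_weights."""
    mean_term = np.mean(sub)                              # mean_loss ignores its targets
    mse_term = np.mean((norm - np.ones_like(norm)) ** 2)  # MSE against a target of ones
    return 1.0 * mean_term + lam * mse_term

sub = np.array([[-0.5], [0.1], [-0.2]])  # D(G(z)) - D(x) per sample
norm = np.array([[1.2], [0.9], [1.0]])   # ||grad D(x_hat)||_2 per sample
lam = 10.0

# Per-sample loss from line 7, then averaged over the batch.
per_sample = sub + lam * (norm - 1.0) ** 2
print(critic_total_loss(sub, norm, lam), np.mean(per_sample))  # the two agree
```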
What is left is computing $\|\nabla_{\hat{x}} D(\hat{x})\|_2$. Keras does not provide this operation, so we write a custom layer for it:
```python
from keras import backend as K
from keras.engine.topology import Layer

class GradNorm(Layer):
    """Return the L2 norm of d(target)/d(wrt) for every sample in the batch."""
    def __init__(self, **kwargs):
        super(GradNorm, self).__init__(**kwargs)

    def build(self, input_shapes):
        super(GradNorm, self).build(input_shapes)

    def call(self, inputs):
        target, wrt = inputs
        grads = K.gradients(target, wrt)  # a list with one gradient tensor
        assert len(grads) == 1
        grad = grads[0]
        # Flatten each sample's gradient, then take its Euclidean norm.
        return K.sqrt(K.sum(K.batch_flatten(K.square(grad)), axis=1, keepdims=True))

    def compute_output_shape(self, input_shapes):
        # One norm per sample: (batch_size, 1), batch size taken from wrt.
        return (input_shapes[1][0], 1)
```
The layer takes two tensors as input: the target tensor $D(\hat{x})$ and the tensor $\hat{x}$ that the gradient is taken with respect to (`wrt`). It returns one gradient norm per sample.
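Two details make this layer work. `K.gradients` differentiates the sum of `target` over the batch, but because each output $D(\hat{x}_i)$ depends only on its own input $\hat{x}_i$ (true for a critic without cross-sample operations such as batch normalization), the gradient of that sum with respect to $\hat{x}_i$ is exactly $\nabla D(\hat{x}_i)$. Then `K.batch_flatten` turns each sample's gradient into a vector before the norm is taken. The NumPy sketch below checks both points with finite differences on a hypothetical toy critic $D(x) = \sum x^2$, whose gradient is $2x$; all helper names are illustrative.

```python
import numpy as np

def toy_critic(x):
    """Hypothetical critic: sum of squares of each sample, shape (batch, 1)."""
    return np.sum(x.reshape(x.shape[0], -1) ** 2, axis=1, keepdims=True)

def grad_of_batch_sum(f, x, h=1e-6):
    """Central finite differences of sum_i f(x)_i with respect to every entry of x."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        step = np.zeros_like(x)
        step[idx] = h
        grad[idx] = (np.sum(f(x + step)) - np.sum(f(x - step))) / (2 * h)
    return grad

def per_sample_grad_norm(grad):
    # Mirrors K.sqrt(K.sum(K.batch_flatten(K.square(grad)), axis=1, keepdims=True))
    flat = grad.reshape(grad.shape[0], -1)
    return np.sqrt(np.sum(flat ** 2, axis=1, keepdims=True))

rng = np.random.default_rng(0)
x_hat = rng.normal(size=(3, 2, 2, 1))           # a batch of three tiny "images"
norms = per_sample_grad_norm(grad_of_batch_sum(toy_critic, x_hat))
expected = 2 * per_sample_grad_norm(x_hat)      # ||2x|| for each sample
print(np.allclose(norms, expected, atol=1e-5))  # True
```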
With those components, we can define our model as:
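```python
from keras.layers import Input
from keras.models import Model, Sequential

# Assumes self.generator / self.discriminator are Keras models and that
# opt (an optimizer) and lmbd (the penalty weight lambda) are defined elsewhere.
gen, dis = self.generator, self.discriminator

# Generator trainer: the critic is frozen while G is trained through it.
gendis = Sequential([gen, dis])
dis.trainable = False
gendis.compile(optimizer=opt, loss=multiple_loss)
# output: D(G(Z)) ===(y_true: -1*ones)===> Loss: (-1) * D(G(Z))

# Critic trainer: inputs are fake, real, and interpolated samples.
shape = dis.get_input_shape_at(0)[1:]
gen_input, real_input, interpolation = Input(shape), Input(shape), Input(shape)
sub = Subtract()([dis(gen_input), dis(real_input)])
norm = GradNorm()([dis(interpolation), interpolation])
dis2batch = Model([gen_input, real_input, interpolation], [sub, norm])
# output: D(G(Z))-D(X), norm ===(y_true: ones, ones)===> Loss: D(G(Z))-D(X) + lmbd*(norm-1)**2
dis.trainable = True
dis2batch.compile(optimizer=opt, loss=[mean_loss, 'mse'], loss_weights=[1.0, lmbd])
```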
Thanks to @bpgw for the comment. To see how these models are trained, please check my code for more details. The key lines are:
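```python
# Generator step (self.gen_trainner is presumably the gendis model above):
# targets of -1 make multiple_loss compute mean(-D(G(Z))), as in line 12.
y = np.ones((nbatch, 1)) * (-1)
g_loss = self.gen_trainner.train_on_batch(Z, y)
# output: D(G(Z)) ===(-1*ones)===> Loss: (-1) * D(G(Z))

# Discriminator step (self.dis_trainner is presumably the dis2batch model above):
# one epsilon ~ U[0, 1] per sample, broadcast over the image dimensions.
epsilon = np.random.uniform(0, 1, size=(nbatch, 1, 1, 1))
interpolation = epsilon * real_img + (1 - epsilon) * gen_img
# Targets: ignored by mean_loss, and ones for the MSE on the gradient norm.
d_loss, d_diff, d_norm = self.dis_trainner.train_on_batch(
    [gen_img, real_img, interpolation], [np.ones((nbatch, 1))] * 2)
```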
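In the discriminator step, `epsilon` has shape `(nbatch, 1, 1, 1)` so that a single random number per sample is broadcast over all pixels and channels of a 4-D image batch (assumed here to be `(nbatch, height, width, channels)`). The NumPy sketch below, with made-up shapes and random images, checks that each interpolated image is a convex combination of its real and generated pair.

```python
import numpy as np

rng = np.random.default_rng(1)
nbatch, height, width, channels = 4, 8, 8, 3  # hypothetical image batch shape
real_img = rng.uniform(0, 1, size=(nbatch, height, width, channels))
gen_img = rng.uniform(0, 1, size=(nbatch, height, width, channels))

# One epsilon per sample, broadcast over height, width, and channels.
epsilon = rng.uniform(0, 1, size=(nbatch, 1, 1, 1))
interpolation = epsilon * real_img + (1 - epsilon) * gen_img
print(interpolation.shape)  # (4, 8, 8, 3)

# Every pixel lies between the corresponding real and generated pixels.
low = np.minimum(real_img, gen_img)
high = np.maximum(real_img, gen_img)
print(np.all((interpolation >= low - 1e-12) & (interpolation <= high + 1e-12)))  # True
```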