[DR021] SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient

Posted on October 14, 2017

The usual way to train an RNN for sequence generation is to maximize the log predictive likelihood of each true token in the training sequence given the previously observed tokens. However, this approach suffers from exposure bias [1] during inference: when generating a long sequence, we feed the RNN's own output back as its next input, and such input patterns might never have been observed in the training set. This discrepancy can therefore accumulate as the generated sequence grows. In [1], the authors proposed scheduled sampling (SS) to address this problem by feeding some of the model's own (synthetic) predictions as inputs during the training phase. However, SS was later shown to be an inconsistent training strategy. Another solution is to build the loss function on the entire generated sequence (e.g., BLEU) instead of on each transition. But such a metric is not accurate when evaluating complex generation tasks like music or dialog.
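To make the scheduled sampling idea concrete, here is a minimal sketch of how the decoder inputs for one training sequence could be built. It is not the implementation from [1]: `predict_next` is a hypothetical stand-in for the RNN that only looks at the previous input token, and `eps` is the probability of feeding the true previous token (teacher forcing).

```python
import random
from typing import Callable, List, Optional

def scheduled_sampling_inputs(targets: List[int],
                              predict_next: Callable[[int], int],
                              eps: float,
                              bos: int = 0,
                              rng: Optional[random.Random] = None) -> List[int]:
    """Decoder inputs for one training sequence under scheduled sampling.

    Step 0 always receives the begin-of-sequence token. At every later step t,
    with probability eps the true previous token targets[t - 1] is fed
    (teacher forcing); otherwise the model's own prediction from the previous
    input is fed instead (hypothetical toy model `predict_next`).
    """
    if rng is None:
        rng = random.Random()
    inputs = [bos]
    for t in range(1, len(targets)):
        if rng.random() < eps:
            inputs.append(targets[t - 1])            # ground-truth previous token
        else:
            inputs.append(predict_next(inputs[-1]))  # model's own output
    return inputs
```

With `eps=1.0` this reduces to ordinary MLE training (always the true tokens); with `eps=0.0` the model always consumes its own outputs, as it does at inference time. In [1], `eps` starts at 1 and is decayed toward 0 as training proceeds.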

In a GAN, the discriminator can precisely locate the difference between real and synthetic data because it is trained iteratively with the generator. However, in an online discussion, Goodfellow pointed out that GANs work poorly at generating discrete tokens: the guidance from the discriminator only makes slight changes to the generator's output, and such a slight change has no counterpart in a limited dictionary space. He recommended training the GAN with RL to generate discrete tokens instead.
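A toy illustration of the "too slight" argument, with made-up logits: a small gradient step on a token's logit does move the continuous probabilities, but the discrete token the generator actually emits (here the argmax) stays the same, so the emitted sequence, and hence the discriminator's input, does not change.

```python
import math
from typing import List

def softmax(xs: List[float]) -> List[float]:
    m = max(xs)                               # subtract max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(xs: List[float]) -> int:
    return max(range(len(xs)), key=lambda i: xs[i])

logits = [2.0, 1.0, 0.5]                      # hypothetical scores over a 3-token dictionary
nudged = [2.0, 1.0 + 1e-3, 0.5]               # a tiny "gradient step" toward token 1

print(softmax(nudged)[1] > softmax(logits)[1])  # True: the probability moves smoothly
print(argmax(logits) == argmax(nudged))         # True: the emitted token does not change
```

Because the token choice is piecewise constant in the parameters, its gradient is zero almost everywhere, which is why an RL-style estimator is used instead of backpropagating through the discrete output.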

 

As shown in the figure, this paper treats the generator as a policy in traditional RL, where the previously generated tokens are the state (stored in the RNN hidden state) and the action is the next token to generate. The discriminator is fed with both real and synthetic data to locate the difference between them. Since the discriminator only scores complete sequences, to evaluate a partial sequence they use another generator (the roll-out policy) to fill in the rest via Monte Carlo search. In this paper, the two generators share the same parameters.
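The Monte Carlo evaluation of a partial sequence can be sketched as follows. This is a hedged, framework-free sketch: the roll-out policy and the discriminator are hypothetical callables (a real setup would use the LSTM generator and the trained discriminator), and `prefix` is the partial sequence ending with the token being evaluated.

```python
import random
from typing import Callable, List, Sequence

# Hypothetical interfaces: a roll-out policy samples the next token given the
# current prefix; a discriminator maps a complete sequence to P(real).
RolloutPolicy = Callable[[Sequence[int], random.Random], int]
Discriminator = Callable[[Sequence[int]], float]

def mc_rollouts(prefix: Sequence[int], rollout_policy: RolloutPolicy,
                T: int, N: int, rng: random.Random) -> List[List[int]]:
    """Complete the prefix to length T, N independent times."""
    samples = []
    for _ in range(N):
        seq = list(prefix)                        # keep the prefix fixed
        while len(seq) < T:
            seq.append(rollout_policy(seq, rng))  # sample the remaining tokens
        samples.append(seq)
    return samples

def q_value(prefix: Sequence[int], rollout_policy: RolloutPolicy,
            discriminator: Discriminator, T: int, N: int,
            rng: random.Random) -> float:
    """Value of the last token in `prefix` given the tokens before it."""
    if len(prefix) == T:
        return discriminator(prefix)              # complete: D's score is the reward
    rollouts = mc_rollouts(prefix, rollout_policy, T, N, rng)
    return sum(discriminator(y) for y in rollouts) / N  # partial: MC average

# Toy usage: tokens are {0, 1}; the "discriminator" likes sequences full of 1s.
def uniform_policy(seq: Sequence[int], rng: random.Random) -> int:
    return rng.randrange(2)

def frac_ones(seq: Sequence[int]) -> float:
    return sum(seq) / len(seq)

print(q_value([1, 0], uniform_policy, frac_ones, T=6, N=16, rng=random.Random(0)))
```

Larger `N` lowers the variance of the estimate at the cost of `N` extra generator runs per token, which is the main expense of this reward scheme.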

 

The pseudo-code looks like:

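MLE means the traditional maximum likelihood estimation method. The roll-out policy $G_\beta$ is the one used to complete a partial sequence; $G_\theta$ is the generator and $D_\phi$ the discriminator. The Q-function, i.e. the expected reward of choosing token $y_t$ after the prefix $Y_{1:t-1}$, is defined by:

$$
Q_{D_\phi}^{G_\theta}(s = Y_{1:t-1}, a = y_t) =
\begin{cases}
\frac{1}{N} \sum_{n=1}^{N} D_\phi(Y_{1:T}^{n}), \quad Y_{1:T}^{n} \in MC^{G_\beta}(Y_{1:t}; N) & \text{for } t < T \\
D_\phi(Y_{1:t}) & \text{for } t = T
\end{cases}
$$

where MC is the sampling set obtained by running the roll-out policy $N$ times, each time keeping the prefix $Y_{1:t}$ and sampling the remaining tokens:

$$
\left\{ Y_{1:T}^{1}, \ldots, Y_{1:T}^{N} \right\} = MC^{G_\beta}(Y_{1:t}; N)
$$

The standard policy gradient

$$
\nabla_\theta J(\theta) \simeq \sum_{t=1}^{T} \mathbb{E}_{y_t \sim G_\theta(y_t \mid Y_{1:t-1})} \left[ \nabla_\theta \log G_\theta(y_t \mid Y_{1:t-1}) \cdot Q_{D_\phi}^{G_\theta}(Y_{1:t-1}, y_t) \right]
$$

and the discriminator loss

$$
\min_\phi \; -\mathbb{E}_{Y \sim p_{\text{data}}} \left[ \log D_\phi(Y) \right] - \mathbb{E}_{Y \sim G_\theta} \left[ \log \left( 1 - D_\phi(Y) \right) \right]
$$

are used.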
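The policy-gradient and discriminator updates can be sketched on a toy problem. This is a hypothetical simplification, not the paper's implementation: the generator is a tabular bigram model (logits `theta[prev][next]`, with an extra row for a begin-of-sequence token) instead of an LSTM, the expectation over $y_t$ is replaced by the single sampled sequence, and the per-step Q-values are assumed to come from a Monte Carlo roll-out. `discriminator_loss` is the binary cross-entropy on the discriminator's probabilities for real and generated samples.

```python
import math
from typing import List, Sequence

def softmax(row: Sequence[float]) -> List[float]:
    m = max(row)                                  # subtract max for stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def policy_gradient_step(theta: List[List[float]], seq: Sequence[int],
                         q_values: Sequence[float], bos: int,
                         lr: float) -> List[List[float]]:
    """One REINFORCE step on a tabular bigram generator, updating theta in place.

    G(y_t | y_{t-1}) = softmax(theta[y_{t-1}])[y_t], with theta[bos] used for
    the first token. The gradient estimate is
    sum_t grad log G(y_t | y_{t-1}) * Q_t for the single sampled sequence.
    """
    grad = [[0.0] * len(row) for row in theta]
    prev = bos
    for y_t, q in zip(seq, q_values):
        probs = softmax(theta[prev])
        for j, p in enumerate(probs):
            # d log softmax(theta[prev])[y_t] / d theta[prev][j] = 1[j == y_t] - p_j
            grad[prev][j] += ((1.0 if j == y_t else 0.0) - p) * q
        prev = y_t
    for row, g_row in zip(theta, grad):
        for j, g in enumerate(g_row):
            row[j] += lr * g                      # gradient ascent on J(theta)
    return grad

def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """-mean log D(real) - mean log(1 - D(fake)), i.e. binary cross-entropy."""
    real_term = -sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = -sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term
```

Tokens that received a high Q-value become more likely after the step, and the gradient of each softmax row sums to zero, so probability mass is only moved between tokens, never created.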

As for the network structure, a standard LSTM is employed as the building block of the generator policy, but a CNN is used as the discriminator: each filter covers a window of a fixed number of tokens (several window sizes are used), and max-over-time pooling is applied after the convolution. I guess they had the same trouble as we did when using an RNN as the discriminator: an RNN compresses every element of the sequence into a single code to judge whether the sequence is real or fake, so its signal has to pass through the whole unrolled recurrence, which makes the gradient very difficult to use. A CNN propagates information forward and backward more directly, because it limits the spatial/temporal relationships it models to a controllable range (i.e., the filter size).
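A minimal sketch of such a discriminator, assuming the tokens are already embedded as vectors of size k: each kernel of shape (window × k) slides over the sequence, ReLU is used as the non-linearity (an assumption), and max-over-time pooling keeps the strongest response per kernel. A single linear layer plus sigmoid stands in for the classifier head; the paper's extra layers and training details are omitted, and all names here are hypothetical.

```python
import math
from typing import List

Vector = List[float]

def conv_max_over_time(embeds: List[Vector], kernel: List[Vector], bias: float) -> float:
    """One filter with window size len(kernel), followed by max-over-time pooling.

    Assumes len(embeds) >= len(kernel). Each window of consecutive embeddings is
    multiplied elementwise with the kernel, summed, passed through ReLU, and the
    maximum over all window positions is returned.
    """
    window = len(kernel)
    feats = []
    for start in range(len(embeds) - window + 1):
        s = bias
        for i in range(window):
            s += sum(e * w for e, w in zip(embeds[start + i], kernel[i]))
        feats.append(max(0.0, s))                 # ReLU
    return max(feats)                             # max over time

def cnn_discriminator(embeds: List[Vector], kernels: List[List[Vector]],
                      biases: List[float], out_w: List[float], out_b: float) -> float:
    """P(real): one pooled feature per kernel -> linear layer -> sigmoid."""
    pooled = [conv_max_over_time(embeds, k, b) for k, b in zip(kernels, biases)]
    logit = out_b + sum(w * f for w, f in zip(out_w, pooled))
    return 1.0 / (1.0 + math.exp(-logit))
```

Because each pooled feature depends only on the best-matching window, a filter's influence is confined to `len(kernel)` neighboring tokens, which is the controllable range mentioned above.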

 

To compare different generative models, they first use a randomly initialized LSTM as the true distribution (the "oracle") and generate sufficient training data from it. I think it's unusual, but it is reasonable to some extent: since the true distribution is known, the generated samples can be evaluated exactly against it. Another thing in the experiment that is worth remembering is the following description:

 

 
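The oracle setup can be sketched without any deep learning framework. As a hypothetical simplification, the randomly initialized LSTM is replaced by a randomly initialized bigram table (the last row is used for the first token). The oracle both generates the training data and scores sequences exactly: the metric below is the average negative log-likelihood of sequences under the oracle, which is what the paper reports as the oracle NLL for the generator's samples.

```python
import math
import random
from typing import List

def random_oracle(vocab_size: int, rng: random.Random) -> List[List[float]]:
    """Randomly initialized bigram table; row = previous token, last row = start."""
    table = []
    for _ in range(vocab_size + 1):
        logits = [rng.gauss(0.0, 1.0) for _ in range(vocab_size)]
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        table.append([e / total for e in exps])  # each row is a distribution
    return table

def sample(oracle: List[List[float]], length: int, rng: random.Random) -> List[int]:
    """Draw one sequence from the oracle (used as 'real' training data)."""
    prev = len(oracle) - 1                        # start row
    seq = []
    for _ in range(length):
        tok = rng.choices(range(len(oracle[prev])), weights=oracle[prev])[0]
        seq.append(tok)
        prev = tok
    return seq

def oracle_nll(oracle: List[List[float]], seqs: List[List[int]]) -> float:
    """Average per-sequence negative log-likelihood of seqs under the oracle."""
    total = 0.0
    for seq in seqs:
        prev = len(oracle) - 1
        for tok in seq:
            total -= math.log(oracle[prev][tok])
            prev = tok
    return total / len(seqs)

rng = random.Random(0)
oracle = random_oracle(5, rng)
data = [sample(oracle, 20, rng) for _ in range(50)]
print(oracle_nll(oracle, data))
```

In the real experiment, `seqs` passed to the metric would be samples from the trained generator, so a lower value means the generator's output is more plausible under the true distribution.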

[1]: Scheduled sampling for sequence prediction with recurrent neural networks.

[DR022]: The length of the generated music in the experiment is too short (32 notes) and simple (88 events). TBA
