Training a GAN is a tricky, unstable process, especially when the goal is to get the generator to produce diverse images from the target distribution. In practice, the generator in a deep convolutional GAN often overfits to its discriminator, which leads to lots of repetitive generated images.
[Figure: Generative Adversarial Networks. Source: ResearchGate]
Feature matching is a method that not only improves the stability of GANs, but does so in a way that makes them useful for semi-supervised training, when you don't have enough labeled data.
The difference between an ordinary GAN and a feature-matching GAN is the training objective for the generator. An ordinary generator tries to maximize the probability that the discriminator assigns to its fakes being real. Feature matching implements the idea that for images to appear "natural", the statistics of the features extracted from generated images should be similar to the statistics of the features extracted from real images. In terms of layers and tensors, this boils down to minimizing the mean squared error between the activations of some deep layer in the discriminator - preferably the one right before the classification head - averaged over real and generated batches.
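Written out (using f for the activations of that deep layer), the generator minimizes the distance between the batch statistics of real and generated features:

$$\left\| \mathbb{E}_{x \sim p_{\text{data}}} f(x) \; - \; \mathbb{E}_{z \sim p_z} f(G(z)) \right\|_2^2$$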
Makes sense, right? It's easier to see in code, though - let's take a look.
```python
import torch
from torch import nn

# we keep the discriminator as simple as possible
class Discriminator(nn.Module):
    def __init__(self, input_size, num_features):
        super().__init__()
        self.features = nn.Linear(input_size, num_features)
        self.classifier = nn.Sequential(
            nn.Linear(num_features, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # here's the main part:
        # instead of returning only the classification,
        # we return both the prediction and the features
        f = self.features(x)
        cls = self.classifier(f)
        return cls, f

# all nets, variables and optimizers are initialized as usual
D = ...
G = ...
x, y = ...
noise = ...

# criterion
feature_matching_criterion = nn.MSELoss()

fake_samples = G(noise)  # perform sampling from the generator
real_samples = ...       # perform sampling from the real data

fake_pred, fake_feats = D(fake_samples)
real_pred, real_feats = D(real_samples)
real_feats = real_feats.detach()  # treat the real features as a constant target

# now, calculating the new objective:
# match the batch means (first-order statistics) of the deep features
fm_loss = feature_matching_criterion(fake_feats.mean(dim=0), real_feats.mean(dim=0))
fm_loss.backward()
```
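To make this concrete, here is a minimal self-contained version of the same step. The tiny linear generator, the dimensions, and the random "real" batch are all illustrative assumptions, not part of a real training setup:

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, num_features, noise_dim, batch = 784, 128, 64, 16

class Discriminator(nn.Module):
    def __init__(self, input_size, num_features):
        super().__init__()
        self.features = nn.Linear(input_size, num_features)
        self.classifier = nn.Sequential(nn.Linear(num_features, 1), nn.Sigmoid())

    def forward(self, x):
        f = self.features(x)
        return self.classifier(f), f

D = Discriminator(input_size, num_features)
G = nn.Sequential(nn.Linear(noise_dim, input_size), nn.Tanh())  # toy generator

noise = torch.randn(batch, noise_dim)
real_samples = torch.randn(batch, input_size)  # stands in for a real data batch

_, fake_feats = D(G(noise))
_, real_feats = D(real_samples)

# match the batch means of the deep features;
# detach the real side so it acts as a constant target
fm_loss = nn.MSELoss()(fake_feats.mean(dim=0), real_feats.mean(dim=0).detach())
fm_loss.backward()
```

After `backward()`, gradients flow through the fake path into the generator, so an optimizer step on `G` would pull the generated feature statistics toward the real ones.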
Like every other trick for training GANs, this one doesn't automatically solve all convergence problems, but it does make training more stable.
It's especially useful for experiments where your classifier underperforms because of a lack of training data - for example, when you're quickly prototyping something for your project.
Using GANs for semi-supervised training is a somewhat underexplored area. The general idea is that if you make the discriminator not only separate fake images from real ones, but also assign correct labels to the real ones, you can leverage lots of unlabeled data to improve your classification results.
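One common way to set this up is to give the discriminator K + 1 outputs: K real classes plus an extra "fake" class. The sketch below assumes a 10-class labeled dataset; the sizes, the random batches, and the class names are illustrative assumptions:

```python
import torch
from torch import nn

torch.manual_seed(0)
num_classes = 10  # e.g. digits

# discriminator that outputs K real classes plus one "fake" class
class SemiSupervisedD(nn.Module):
    def __init__(self, input_size, num_features, num_classes):
        super().__init__()
        self.features = nn.Linear(input_size, num_features)
        self.classifier = nn.Linear(num_features, num_classes + 1)

    def forward(self, x):
        f = self.features(x)
        return self.classifier(f), f

D = SemiSupervisedD(784, 128, num_classes)
x_labeled = torch.randn(8, 784)                    # labeled real batch
y = torch.randint(0, num_classes, (8,))            # their labels
x_fake = torch.randn(8, 784)                       # stands in for G's output

logits_real, _ = D(x_labeled)
logits_fake, _ = D(x_fake)

ce = nn.CrossEntropyLoss()
# labeled real images must get their true class...
supervised_loss = ce(logits_real, y)
# ...while generated images must land in the extra "fake" class
fake_class = torch.full((8,), num_classes, dtype=torch.long)
unsupervised_loss = ce(logits_fake, fake_class)
d_loss = supervised_loss + unsupervised_loss
```

Unlabeled real images can be folded in with a third term that only pushes them away from the fake class, which is where the unlabeled data pays off.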
Intuitively, if you manage to generate samples whose deep representations are indistinguishable from the representations of real ones, you force the discriminator to learn the underlying patterns (distributions) more efficiently - and, evidently, more accurately than you would with a purely supervised model.
To summarize, this method is a good alternative to traditional training of a generator in the following use cases:
- Unstable convergence of a generative model
- The need for an additional objective that matches the representations of real and fake images