Inception Score — evaluating the realism of your GAN

An excellent paper by Salimans et al. not only explores a variety of methods for improving the visual quality of generated images, but also proposes a very cool metric called the Inception Score. It's particularly interesting because it seems to be a reliable way to "quantify realism" in GANs, tackling a problem that has existed since the invention of adversarial training.

In this article I'm going to explain the Inception Score and the reasoning behind it.

Introduction

Visual quality of generated images is a highly subjective thing: there's no definitive way to formalize it, even for research purposes, which is why every paper on GANs contains sampled images.
[Image: gan-visual]
Illustration from Johnson et al. (2016) — you get the idea

However, while training and evaluating GANs, you'll be more interested in a more or less reliable way to filter out weak (e.g. collapsed) networks, just to save time. Salimans et al. provide a seemingly convoluted but, as you'll see in a moment, very intuitive formula:

$$IS = \exp\Big(\mathbb{E}_x\Big[KL\Big(p(y|x) \,\|\, p(y)\Big)\Big]\Big)$$

But before we deconstruct it, let's understand the bigger picture.

Motivation

Consider this setting: a zoo of GANs you've trained has generated several sets of images, $X_1, X_2, ..., X_N$, each trying to mimic the original set $X_{real}$. If you had a perfect way to rank the realism of these sets, say, a function $\rho(X)$, then $\rho(X_{real})$ would obviously be the highest of all. The real question is:

How can such a function be formulated in terms of statistics / information theory?

The answer depends on the requirements for the images. Two criteria immediately come to mind:

  1. A human, looking at a single image, would be able to confidently determine what's in it (saliency).
  2. A human, looking at a set of various images, would say that the set has lots of different objects (diversity).

At least, that's what everyone expects from a good generative model. Those who have tried training GANs themselves will have noticed that you usually end up with only one of the two criteria covered.

Saliency vs. diversity

Broadly speaking, these two criteria are represented by two components of the formula:

  1. Saliency is expressed as $p(y|x)$: the distribution of classes for any individual image should have low entropy (think of it as a single high score and the rest very low).
  2. Diversity is expressed as $p(y)$: the overall distribution of classes across the sampled data should have high entropy, which would mean the absence of dominating classes and something closer to a well-balanced training set (see the toy sketch below).
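
To make this concrete, here's a minimal toy sketch (the numbers are made up for illustration, not taken from the paper) showing both criteria at once: sharp per-image distributions with a spread-out marginal.

import torch

# Toy p(y|x) for three "images" over three classes: each image is
# classified confidently (saliency), but into different classes (diversity).
p_yx = torch.tensor([[0.98, 0.01, 0.01],
                     [0.01, 0.98, 0.01],
                     [0.01, 0.01, 0.98]])

per_image_entropy = -(p_yx * p_yx.log()).sum(dim=1)  # low: ~0.11 nats each
p_y = p_yx.mean(dim=0)                               # marginal class distribution
marginal_entropy = -(p_y * p_y.log()).sum()          # high: ~1.10 nats (= log 3)

print(per_image_entropy, marginal_entropy)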

Kullback-Leibler divergence

KL divergence (often loosely called KL distance) is a measure of the information loss that occurs when an approximation is used in place of a true empirical distribution. Note that in the "Saliency vs. diversity" setting, both distributions can be used as approximations of each other: you can guess the most probable class of an image using the mean score distribution, or you can estimate the mean by looking at a single example. Ideally, neither of these approximations should be good; that's why the formula uses KL divergence.
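
For reference, the KL term inside the expectation expands to the standard definition:

$$KL\Big(p(y|x) \,\|\, p(y)\Big) = \sum_y p(y|x) \log \frac{p(y|x)}{p(y)}$$

This quantity is large precisely when each image's class distribution is sharp while the marginal is spread out, which ties both criteria into a single number.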

Inception Score is a measure of "on average, how different is the score distribution for a generated image from the overall class balance"

And that's it. One important thing to keep in mind (and, actually, the most fascinating part) is that to compute this score for a set of generated images you need a good image classifier. Hence the name of the metric: to calculate the distributions, the authors used a pretrained Inception network.

Implementation

And, as usual, a PyTorch implementation:

import math
import torch
import torch.nn.functional as F
from torchvision.models import inception_v3


# Inception v3 expects 299x299 inputs normalized with ImageNet statistics.
# eval() disables the auxiliary classifier, so the forward pass returns
# plain logits instead of a (logits, aux_logits) tuple.
net = inception_v3(pretrained=True).cuda()
net.eval()


def inception_score(images, batch_size=5):
    scores = []
    n_batches = int(math.ceil(float(len(images)) / float(batch_size)))
    with torch.no_grad():  # no gradients needed for evaluation
        for i in range(n_batches):
            batch = torch.cat(images[i * batch_size: (i + 1) * batch_size], 0).cuda()
            scores.append(net(batch))
    p_yx = F.softmax(torch.cat(scores, 0), dim=1)  # p(y|x) for every image
    p_y = p_yx.mean(0, keepdim=True)               # marginal p(y), broadcast over rows
    KL_d = p_yx * (torch.log(p_yx) - torch.log(p_y))
    # sum over classes -> KL per image; average over images; exponentiate
    final_score = KL_d.sum(dim=1).mean().exp()
    return final_score.item()
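
A hypothetical usage sketch (the generator, its latent size, and the assumption that it outputs RGB images in [0, 1] are mine, not part of the original snippet): Inception v3 wants 299x299 inputs with ImageNet normalization.

import torch
import torch.nn.functional as F

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

with torch.no_grad():
    z = torch.randn(100, 128).cuda()   # assumed latent dimension
    fake = generator(z)                # hypothetical generator, outputs [0, 1] RGB
    fake = F.interpolate(fake, size=(299, 299), mode='bilinear', align_corners=False)
    fake = (fake - mean) / std         # ImageNet normalization

images = list(fake.split(1, dim=0))    # list of single-image tensors, as the function expects
print(inception_score(images))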