1. Auto Encoder

Before formally introducing the variational autoencoder (VAE), let’s first look at the structure of a plain autoencoder (AE). It consists of an Encoder and a Decoder.

The input data is fed to the Encoder, which produces the ‘hidden states’ (the latent code). The Decoder then takes these ‘hidden states’ and tries to reconstruct the input, so the output should be as close to the input as possible.

Usually, the dimension of the ‘hidden states’ is chosen to be smaller than that of the input, so the AE also performs dimensionality reduction.

2. Variational Auto Encoder

The difference between a VAE and an AE is that the ‘hidden states’ of a VAE are not a fixed vector but a probability distribution.

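For example, to obtain Gaussian ‘hidden states’, the Encoder outputs a ‘mean vector’ and a (diagonal) ‘covariance vector’. In the code below, the second vector is parameterized as the log standard deviation, log σ.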

The loss function consists of a ‘Reconstruction Loss’ (RL) and a ‘KL Divergence’ (KLD) term. The RL is the same as in the AE. The KLD pushes the distribution of the ‘hidden states’ toward the prior, a standard Gaussian $\mathcal{N}(0, I)$.

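Reparameterization Trick

Sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ directly is not differentiable with respect to $\mu$ and $\sigma$. Instead, we sample $\epsilon \sim \mathcal{N}(0, I)$ and set $z = \mu + \sigma \odot \epsilon$, so gradients can flow through $\mu$ and $\sigma$ back into the Encoder.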

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib
matplotlib.use('AGG')#或者PDF, SVG或PS
import matplotlib.pyplot as plt
from torchvision.utils import save_image

class Normal(object):
    """Plain container for the parameters of a diagonal Gaussian.

    Args:
        mu: mean tensor; its shape determines the shape of the default
            ``v`` and ``r`` tensors.
        sigma: standard-deviation tensor (either the stdev diagonal itself,
            or the stdev diagonal from a decomposition).
        log_sigma: log standard deviation; stored as ``self.logsigma`` and
            not checked against ``sigma``.
        v: optional auxiliary tensor; defaults to a zero float32 tensor
            shaped like ``mu``.
        r: optional auxiliary tensor; defaults to a zero float32 tensor
            shaped like ``mu``.
    """

    def __init__(self, mu, sigma, log_sigma, v=None, r=None):
        self.mu = mu
        self.sigma = sigma  # either stdev diagonal itself, or stdev diagonal from decomposition
        self.logsigma = log_sigma
        # torch tensors expose .size(); .get_shape() is the TensorFlow API
        # and raises AttributeError on a torch.Tensor.
        dim = mu.size()
        # Zero-initialise the defaults: torch.FloatTensor(*dim) would return
        # uninitialised memory.
        if v is None:
            v = torch.zeros(dim, dtype=torch.float32)
        if r is None:
            r = torch.zeros(dim, dtype=torch.float32)
        self.v = v
        self.r = r

class Encoder(torch.nn.Module):
    """Two-layer perceptron mapping ``D_in`` features to ``D_out`` features.

    Layout: Linear(D_in, H) -> ReLU -> Linear(H, D_out) -> ReLU, so every
    output element is non-negative.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # The attribute names become state_dict keys, so they stay fixed.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        hidden = F.relu(self.linear1(x))
        features = self.linear2(hidden)
        return F.relu(features)

class Decoder(torch.nn.Module):
    """Two-layer perceptron mapping a ``D_in``-dim code to ``D_out`` features.

    Layout: Linear(D_in, H) -> ReLU -> Linear(H, D_out) -> ReLU. The output
    is therefore non-negative but not bounded above.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # The attribute names become state_dict keys, so they stay fixed.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        # Each linear layer is followed by a ReLU, including the last one.
        for layer in (self.linear1, self.linear2):
            x = F.relu(layer(x))
        return x

class VAE(torch.nn.Module):
    """Variational auto-encoder built from a user-supplied encoder and decoder.

    The encoder output (``hidden_dim`` features) is projected by two linear
    heads to the mean and the log standard deviation of a diagonal Gaussian
    over a ``latent_dim``-dimensional latent space. A sample from it is
    passed to the decoder.

    After every forward pass, ``self.z_mean`` and ``self.z_sigma`` hold the
    mean and standard deviation used for the sample, so the caller can
    compute the KL term.

    Args:
        encoder: module whose output has ``hidden_dim`` features.
        decoder: module that accepts ``latent_dim`` input features.
        hidden_dim: width of the encoder output (default 100).
        latent_dim: latent size; ``None`` keeps the class default of 8.
    """

    latent_dim = 8

    def __init__(self, encoder, decoder, hidden_dim=100, latent_dim=None):
        super(VAE, self).__init__()
        if latent_dim is not None:
            # Instance attribute shadows the class-level default.
            self.latent_dim = latent_dim
        self.encoder = encoder
        self.decoder = decoder
        self._enc_mu = torch.nn.Linear(hidden_dim, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(hidden_dim, self.latent_dim)

    def _sample_latent(self, h_enc):
        """Return a latent sample z ~ N(mu, sigma^2) via the reparameterization trick.

        Also stores ``mu`` in ``self.z_mean`` and ``sigma`` in ``self.z_sigma``.
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        # Predicting log(sigma) and exponentiating keeps sigma strictly positive.
        sigma = torch.exp(log_sigma)
        # eps ~ N(0, I) is drawn on sigma's device/dtype and carries no
        # gradient, so z stays differentiable w.r.t. mu and sigma.
        eps = torch.randn_like(sigma)

        self.z_mean = mu
        self.z_sigma = sigma

        return mu + sigma * eps  # Reparameterization trick

    def forward(self, state):
        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)

def latent_loss(z_mean, z_stddev):
    """KL divergence between N(z_mean, z_stddev^2) and N(0, 1), averaged.

    Computes ``0.5 * mean(mu^2 + sigma^2 - log(sigma^2) - 1)``. The mean is
    taken over every element (batch and latent dimensions alike), not
    summed over the latent dimensions.
    """
    variance = z_stddev * z_stddev
    per_element = z_mean * z_mean + variance - torch.log(variance) - 1
    return 0.5 * torch.mean(per_element)

if __name__ == '__main__':
    # Train a VAE on MNIST:
    # 784 pixels -> 100 encoder features -> 8-d latent -> 784-pixel reconstruction.
    input_dim = 28 * 28
    batch_size = 32
    num_epochs = 10000

    # ToTensor converts each image to a float tensor in [0, 1] of shape (1, 28, 28).
    mnist = torchvision.datasets.MNIST(
        root='../data/', train=True,
        transform=transforms.Compose([transforms.ToTensor()]), download=True)
    dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size,
                                             shuffle=True, num_workers=8)

    print('Number of samples: ', len(mnist))

    encoder = Encoder(input_dim, 100, 100)
    decoder = Decoder(8, 100, input_dim)
    vae = VAE(encoder, decoder)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(vae.parameters(), lr=0.0002)

    losse = []  # mean training loss of each epoch, used for the loss curve
    inputs = None
    for epoch in range(num_epochs):
        running_loss, n_batches = 0.0, 0
        for inputs, _classes in dataloader:
            # Flatten (B, 1, 28, 28) -> (B, 784) using the actual batch size,
            # so a short final batch is never resized to a fixed size.
            inputs = inputs.view(inputs.size(0), -1)
            optimizer.zero_grad()
            dec = vae(inputs)
            ll = latent_loss(vae.z_mean, vae.z_sigma)
            loss = criterion(dec, inputs) + ll
            # Backpropagate and update the weights; without these calls the
            # model would never learn.
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        l = running_loss / max(n_batches, 1)
        losse.append(l)
        print(epoch, l)

        if epoch % 10 == 0 and inputs is not None:
            # The AGG backend never displays figures, so write them to disk.
            plt.figure()
            plt.plot(losse, '-')
            plt.xlabel('epoch')
            plt.ylabel('loss')
            plt.savefig('vae_loss.png')
            plt.close()
            # No autograd graph is needed just to visualise reconstructions.
            with torch.no_grad():
                recon = vae(inputs)
            k = min(8, inputs.size(0))
            # Top row: inputs, bottom row: their reconstructions.
            grid = torch.cat([inputs[:k], recon[:k]]).view(-1, 1, 28, 28)
            save_image(grid, 'vae_recon_%d.png' % epoch, nrow=k)
        if epoch % 50 == 0:
            torch.save(vae.state_dict(), 'vae.pth')