VAE
1. Auto Encoder
Before formally introducing the VAE, let's first look at the structure of the AE. It consists of an Encoder and a Decoder.
The input is fed to the Encoder, which produces the 'hidden states' (a latent code). The Decoder then takes these 'hidden states' and tries to recover the input, so the output should be as close to the input as possible.
Usually the dimension of the 'hidden states' is smaller than that of the input, so the AE also performs dimensionality reduction.
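As a concrete illustration, here is a minimal sketch of an AE in PyTorch (the layer sizes are arbitrary choices for illustration, not taken from any particular paper):

import torch
from torch import nn

class AE(nn.Module):
    def __init__(self, input_dim=784, code_dim=32):
        super().__init__()
        # Encoder compresses the input into a low-dimensional code
        self.encoder = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU(),
                                     nn.Linear(128, code_dim))
        # Decoder tries to reconstruct the input from the code
        self.decoder = nn.Sequential(nn.Linear(code_dim, 128), nn.ReLU(),
                                     nn.Linear(128, input_dim))

    def forward(self, x):
        code = self.encoder(x)     # 'hidden states': 32 dims < 784 dims
        return self.decoder(code)  # reconstruction, trained e.g. with MSE against x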
2. Variational Auto Encoder
The difference between the VAE and the AE is that the 'hidden states' of the VAE are not fixed vectors but samples from a distribution.
For example, to parameterize a diagonal Gaussian over the 'hidden states', the Encoder outputs a 'mean vector' and a '(log-)variance vector'.
The loss function has two parts: a 'Reconstruction Loss' and a 'KL Divergence'. The 'RL' is the same as in the AE. The 'KLD' pushes the distribution of the 'hidden states' towards a standard normal N(0, I).
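For a diagonal Gaussian q(z|x) = N(mu, sigma^2) and the standard normal prior p(z) = N(0, I), the 'KLD' has a simple closed form, which is exactly what the latent_loss function in the code below computes (the code averages over dimensions instead of summing, which only rescales the term):

\[
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2) \,\big\|\, \mathcal{N}(0, I)\big)
= \frac{1}{2} \sum_{i=1}^{d} \left( \mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1 \right)
\]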
3. Reparameterization Trick
Sampling z ~ N(mu, sigma^2) directly is not differentiable with respect to mu and sigma, so gradients could not flow back into the Encoder. The trick is to sample eps ~ N(0, I) and compute z = mu + sigma * eps: the sample still has the right distribution, but mu and sigma now stay inside the computation graph.
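In isolation, the trick is only a few lines (a minimal sketch; mu and log_sigma stand in for the Encoder outputs):

import torch

mu = torch.zeros(32, 8, requires_grad=True)         # stand-in for the encoder's mean head
log_sigma = torch.zeros(32, 8, requires_grad=True)  # stand-in for the encoder's log-std head
sigma = torch.exp(log_sigma)
eps = torch.randn_like(sigma)  # eps ~ N(0, I); the randomness carries no gradient
z = mu + sigma * eps           # z ~ N(mu, sigma^2), differentiable w.r.t. mu and log_sigma
z.sum().backward()             # gradients flow into mu and log_sigma

The full training script below puts everything together on MNIST.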
import torch
import torch.utils.data
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
import torchvision
from torchvision import transforms
import matplotlib
matplotlib.use('AGG')  # or PDF, SVG, PS
import matplotlib.pyplot as plt
class Normal(object):
    """Container for the parameters of a Gaussian (a helper; unused by the script below)."""
    def __init__(self, mu, sigma, log_sigma, v=None, r=None):
        self.mu = mu
        self.sigma = sigma  # either the stdev diagonal itself, or the stdev diagonal from a decomposition
        self.logsigma = log_sigma
        dim = mu.size()  # PyTorch tensors use .size(); .get_shape() is TensorFlow
        if v is None:
            v = torch.FloatTensor(*dim)
        if r is None:
            r = torch.FloatTensor(*dim)
        self.v = v
        self.r = r
class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))


class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))
class VAE(torch.nn.Module):
    latent_dim = 8

    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Two heads map the 100-dim encoder features into the latent space
        self._enc_mu = torch.nn.Linear(100, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(100, self.latent_dim)

    def _sample_latent(self, h_enc):
        """Return a latent sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        self.z_mean = mu      # stored for the KL term
        self.z_sigma = sigma
        eps = torch.randn_like(sigma)  # eps ~ N(0, I), carries no gradient
        return mu + sigma * eps        # reparameterization trick

    def forward(self, state):
        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)
def latent_loss(z_mean, z_stddev):
    """Closed-form KL divergence between N(z_mean, z_stddev^2) and N(0, I), averaged over dimensions."""
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)
if __name__ == '__main__':
    input_dim = 28 * 28
    batch_size = 32

    mnist = torchvision.datasets.MNIST(root='../data/', train=True,
                                       transform=transforms.Compose([transforms.ToTensor()]),
                                       download=True)
    dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size,
                                             shuffle=True, num_workers=8)
    print('Number of samples: ', len(mnist))

    encoder = Encoder(input_dim, 100, 100)
    decoder = Decoder(VAE.latent_dim, 100, input_dim)
    vae = VAE(encoder, decoder)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(vae.parameters(), lr=0.0002)

    l = None
    losses = []
    for epoch in range(10000):
        for i, data in enumerate(dataloader, 0):
            inputs, classes = data
            inputs = inputs.view(-1, input_dim)  # flatten 28x28 images; also works for a smaller last batch
            optimizer.zero_grad()
            dec = vae(inputs)
            ll = latent_loss(vae.z_mean, vae.z_sigma)
            loss = criterion(dec, inputs) + ll  # reconstruction loss + KL divergence
            loss.backward()
            optimizer.step()
            l = loss.item()
        print(epoch, l)
        losses.append(l)
        if epoch % 10 == 0:
            plt.plot(losses, '-')
            plt.xlabel('epoch')
            plt.ylabel('loss')
            plt.title('Losses')
            plt.savefig('Losses.jpg')
            plt.close()
            # Compare an input digit (top) with its reconstruction (bottom)
            plt.subplot(2, 1, 1)
            plt.imshow(inputs[0].detach().numpy().reshape(28, 28), cmap='gray')
            plt.subplot(2, 1, 2)
            plt.imshow(vae(inputs)[0].detach().numpy().reshape(28, 28), cmap='gray')
            plt.savefig('picture_G.jpg')
            plt.close()
        if epoch % 50 == 0:
            torch.save(vae.state_dict(), 'vae.pth')
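Once training is done, the Decoder alone acts as a generator: sample z from the prior N(0, I) and decode it. A minimal sketch, reusing the decoder object and the 8-dim latent space from the script above:

with torch.no_grad():
    z = torch.randn(16, VAE.latent_dim)  # z ~ N(0, I)
    samples = decoder(z)                 # decode into 16 generated digits
for j in range(16):
    plt.subplot(4, 4, j + 1)
    plt.imshow(samples[j].numpy().reshape(28, 28), cmap='gray')
    plt.axis('off')
plt.savefig('samples.jpg')
plt.close()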