VAE Variants
Variational autoencoders (VAEs) learn latent representations through probabilistic encoding; the variants below extend the basic VAE for specific applications and better generation quality.
Standard VAE
The basic VAE learns to encode and decode through a variational bottleneck.
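Concretely, training minimizes the negative evidence lower bound (ELBO), which splits into a reconstruction term and a KL regularizer pulling the approximate posterior toward the prior:

\mathcal{L}(\theta, \phi; x) = -\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] + D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)

The vae_loss function below implements this with binary cross-entropy as the reconstruction term and the closed-form KL between a diagonal Gaussian posterior and a standard normal prior.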
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim=512, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar
def vae_loss(recon, x, mu, logvar, beta=1.0):
    # Reconstruction term plus KL divergence between q(z|x) and the standard normal prior
    recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss
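As a quick illustration of how the model and loss fit together, here is a minimal single-step training sketch; the 784-dimensional input (flattened 28x28 images), the random stand-in batch, and the Adam settings are illustration assumptions rather than part of the model.

# Minimal training-step sketch (assumes flattened inputs in [0, 1], e.g. 28x28 images -> 784)
model = VAE(input_dim=784)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.rand(64, 784)  # stand-in batch; replace with a real dataloader batch
recon, mu, logvar = model(x)
loss = vae_loss(recon, x, mu, logvar)
optimizer.zero_grad()
loss.backward()
optimizer.step()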
Convolutional VAE
Convolutional layers provide a better inductive bias for image data than fully connected layers. The architecture below assumes 64x64 inputs, which the four stride-2 convolutions reduce to a 4x4 feature map.
class ConvVAE(nn.Module):
    def __init__(self, in_channels=3, latent_dim=256):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder: 64x64 -> 32 -> 16 -> 8 -> 4 spatial resolution
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)
        # Decoder: mirrors the encoder with transposed convolutions
        self.fc_decode = nn.Linear(latent_dim, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.fc_decode(z)
        h = h.view(-1, 256, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
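A sketch of generating new images by decoding draws from the prior; it assumes a trained model on 3-channel 64x64 images, which is what the 256 * 4 * 4 bottleneck above implies.

# Sampling sketch: decode latent vectors drawn from the N(0, I) prior (after training)
model = ConvVAE(in_channels=3, latent_dim=256)
with torch.no_grad():
    z = torch.randn(16, 256)       # 16 latent samples
    samples = model.decode(z)      # -> (16, 3, 64, 64), in [0, 1] thanks to the final Sigmoid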
Beta-VAE
Beta-VAE encourages disentangled representations by weighting the KL term with a factor beta > 1, trading some reconstruction quality for a more factorized latent space.
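Written as a loss to minimize, this is the standard VAE objective with the KL term scaled by beta (beta = 1 recovers the plain VAE; the code below uses a summed MSE as the reconstruction term):

\mathcal{L}_{\beta} = -\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] + \beta \, D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)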
class BetaVAE(nn.Module):
    def __init__(self, input_dim, latent_dim=10, beta=4.0):
        super().__init__()
        self.beta = beta
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick, inlined
        recon = self.decoder(z)
        return recon, mu, logvar

    def loss(self, recon, x, mu, logvar):
        recon_loss = F.mse_loss(recon, x, reduction="sum")
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + self.beta * kl_loss
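Disentanglement is usually inspected by traversing one latent dimension at a time while holding the rest fixed. The sketch below does this for a single dimension; the chosen dimension, sweep range, and 784-dimensional input are arbitrary illustration choices.

# Latent traversal sketch (after training): vary a single latent dimension and decode
model = BetaVAE(input_dim=784, latent_dim=10, beta=4.0)
with torch.no_grad():
    z = torch.zeros(7, 10)                      # start from the prior mean
    z[:, 3] = torch.linspace(-3.0, 3.0, 7)      # sweep dimension 3 across [-3, 3]
    traversal = model.decoder(z)                # decoded outputs should vary along one factor if disentangled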
VQ-VAE
Vector Quantized VAE (VQ-VAE) replaces the continuous Gaussian latent with discrete codes: each encoder output vector is snapped to its nearest entry in a learned codebook.
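The training loss combines reconstruction with two quantization terms, where sg[.] denotes stop-gradient, z_e(x) is the encoder output, and e is the selected codebook entry. In the code below the second term is codebook_loss, and the third, weighted by commitment_cost (beta), is commitment_loss:

\mathcal{L} = \|x - \hat{x}\|_2^2 + \|\mathrm{sg}[z_e(x)] - e\|_2^2 + \beta \, \|z_e(x) - \mathrm{sg}[e]\|_2^2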
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z):
        # Flatten (B, C, H, W) -> (B*H*W, C) so each spatial vector is quantized independently
        z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)
        # Squared Euclidean distance from each vector to every codebook entry
        distances = (
            z_flat.pow(2).sum(dim=1, keepdim=True)
            - 2 * z_flat @ self.embedding.weight.t()
            + self.embedding.weight.pow(2).sum(dim=1)
        )
        indices = distances.argmin(dim=1)
        z_q = self.embedding(indices).view(z.shape[0], z.shape[2], z.shape[3], -1)
        z_q = z_q.permute(0, 3, 1, 2)
        # Commitment loss keeps encoder outputs close to their chosen codes;
        # codebook loss moves the embeddings toward the encoder outputs
        commitment_loss = F.mse_loss(z_q.detach(), z)
        codebook_loss = F.mse_loss(z_q, z.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss
        # Straight-through estimator: copy gradients from z_q back to z
        z_q = z + (z_q - z).detach()
        return z_q, loss, indices
class VQVAE(nn.Module):
    def __init__(self, in_channels=3, hidden_dim=128, num_embeddings=512, embedding_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, embedding_dim, 1)
        )
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim)
        self.decoder = nn.Sequential(
            nn.Conv2d(embedding_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, in_channels, 4, stride=2, padding=1)
        )

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, vq_loss, indices = self.quantizer(z_e)
        recon = self.decoder(z_q)
        return recon, vq_loss, indices

    def loss(self, recon, x, vq_loss):
        recon_loss = F.mse_loss(recon, x)
        return recon_loss + vq_loss
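Once trained, the quantizer's indices give a compact grid of discrete tokens per image, which is what makes fitting an autoregressive prior (e.g. a PixelCNN or a transformer over code indices) straightforward. A rough sketch of extracting those codes, assuming 3-channel 64x64 inputs:

# Sketch: extract discrete code indices for downstream autoregressive modeling (after training)
model = VQVAE()
with torch.no_grad():
    x = torch.rand(8, 3, 64, 64)                 # stand-in batch of images
    z_e = model.encoder(x)                       # (8, 64, 16, 16) continuous features
    _, _, indices = model.quantizer(z_e)         # flat indices into the 512-entry codebook
    codes = indices.view(8, 16, 16)              # one discrete token per spatial position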
Key Takeaways
VAE variants address different limitations of the standard VAE. Convolutional VAEs provide better image modeling through spatial inductive bias. Beta-VAE encourages disentangled representations with increased KL weighting. VQ-VAE uses discrete latent codes for sharper reconstructions and easier autoregressive modeling. Each variant trades off between reconstruction quality, latent structure, and generation diversity.