I am trying to find a low-dimensional latent space representation for a bunch of simulated data. No matter what VAE architecture I try and no matter how I tweak it, the output of the VAE is always the same, where all values hover around 0; it does not reconstruct the input at all. As far as I can understand, the problem is known as posterior collapse in the literature.
I use ray.tune to find hyper-parameters that work, and the lowest validation loss I ever get is val_loss ~ 1. I picked one hyper-parameter configuration that yields val_loss ~ 1 and used it to train a VAE in the script below. The data in this example is 15-dimensional and drawn from a correlated multivariate normal distribution.
import math
import os
import numpy as np
from scipy.stats import random_correlation, loguniform
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
import time
class MDVAE(nn.Module):
    def __init__(
        self,
        n_data: int,
        n_latent=0.3,
        n_hidden=0.6,
        n_hidden_layers=1,
        bias=True,
        activation=nn.LeakyReLU(0.01),
    ):
        super().__init__()
        # float arguments are interpreted as fractions of the input dimension
        if isinstance(n_latent, float):
            n_latent = math.ceil(n_latent * n_data)
        if isinstance(n_hidden, float):
            n_hidden = math.ceil(n_hidden * n_data)
        if (n_latent > n_hidden) or (n_latent > n_data):
            raise ValueError("n_latent must not exceed n_hidden or n_data")
        encoder_layers = [nn.Linear(n_data, n_hidden, bias=bias), activation]
        for i in range(n_hidden_layers):
            encoder_layers.extend([nn.Linear(n_hidden, n_hidden, bias=bias), activation])
        self.encoder = nn.Sequential(*encoder_layers)
        self.mean_lay = nn.Linear(n_hidden, n_latent)
        self.log_var_lay = nn.Linear(n_hidden, n_latent)
        # mirror the encoder to build the decoder
        decoder_layers = []
        for layer in encoder_layers:
            if layer == activation:  # skip the shared activation instances
                continue
            decoder_layers.extend([nn.Linear(layer.out_features, layer.in_features), activation])
        decoder_layers.extend([nn.Linear(n_latent, n_hidden)])
        self.decoder = nn.Sequential(*decoder_layers[::-1])

    def encode_and_sample(self, x):
        hidden = self.encoder(x)
        mean = self.mean_lay(hidden)
        log_var = self.log_var_lay(hidden)
        # reparametrization trick: z = mu + sigma * eps, eps ~ N(0, I)
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        z = mean + std * epsilon
        return z, mean, log_var

    def forward(self, x):
        z, mean, log_var = self.encode_and_sample(x)
        return self.decoder(z), mean, log_var
class VAE_Dataset(Dataset):
    def __init__(self, data: torch.Tensor):
        self.data = data

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # the VAE reconstructs its input, so input and target are identical
        return self.data[idx], self.data[idx]

_PRR = lambda x: float(x.to('cpu').data.numpy())
def ray_train(config, cwd, n_data, n_epoch=5, show_progress=False):
    n_hidden = float(config.get('n_hidden', 0.6))
    n_latent = int(config.get('n_latent', 0.3))
    n_hidden_layers = int(config.get('n_hidden_layers', 2))
    batch_size = int(config.get('batch_size', 64))
    learning_rate = float(config.get('learning_rate', 1e-3))
    weight_decay = float(config.get('weight_decay', 1e-4))
    KL_factor = float(config.get('KL_factor', 1.0))
    train_ds = torch.load(os.path.join(cwd, 'train_ds.pt'))
    val_ds = torch.load(os.path.join(cwd, 'val_ds.pt'))
    x_val, y_val = val_ds[:]
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    mdvae = MDVAE(
        n_data=n_data,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_hidden_layers=n_hidden_layers,
    )
    loss_f = nn.MSELoss()
    optimizer = torch.optim.Adam(mdvae.parameters(), lr=learning_rate, weight_decay=weight_decay)
    if show_progress:
        raise NotImplementedError
    for epoch in range(n_epoch):
        for i, (x, y) in enumerate(train_loader):
            x_hat, mean, log_var = mdvae.forward(x)
            reconstruct = loss_f(x_hat, y)
            # closed-form KL divergence between N(mean, exp(log_var)) and N(0, I)
            KL_div = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
            loss = reconstruct + KL_factor * KL_div
            optimizer.zero_grad()
            # skip the update if the loss is NaN or infinite
            if not (torch.isnan(loss) or torch.isinf(loss)):
                loss.backward()
                optimizer.step()
        # validation loss once per epoch
        with torch.no_grad():
            x_val_hat, mean, log_var = mdvae.forward(x_val)
            reconstruct_val = loss_f(x_val_hat, y_val)
            KL_div_val = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
            val_loss = reconstruct_val + KL_factor * KL_div_val
            print(_PRR(val_loss))
    return mdvae
def prepare_datasets(
    cwd,
    n_data,
    n_samples=20000,
    val_frac=0.1
):
    # reproducible random numbers
    rng = torch.Generator()
    rng.manual_seed(3)
    np_rng = np.random.RandomState(3)
    # create a dataset with correlated variables
    eigs = np_rng.rand(n_data)
    eigs = eigs * (n_data / sum(eigs))  # eigenvalues of a correlation matrix sum to n_data
    corr = torch.from_numpy(
        random_correlation.rvs(eigs, random_state=np_rng)
    )
    chol = torch.linalg.cholesky(corr)
    data = torch.randn((n_samples, n_data), generator=rng, dtype=torch.float64)
    data = data @ chol  # correlated data
    data = (data - torch.mean(data, 0)) / torch.std(data, 0)  # normalized data
    dataset = VAE_Dataset(data.float())
    n_validate = math.ceil(val_frac * n_samples)
    # split into train and validation sets
    train_ds, val_ds = random_split(
        dataset,
        lengths=(len(dataset) - n_validate, n_validate),
    )
    torch.save(train_ds, os.path.join(cwd, 'train_ds.pt'))
    torch.save(val_ds, os.path.join(cwd, 'val_ds.pt'))
if __name__ == "__main__":
    param_space = {
        'batch_size': 32,
        'n_hidden': .8,
        'n_latent': 10,
        'n_hidden_layers': 4,
        'learning_rate': 0.0001,
        'KL_factor': 0.1,
    }
    n_samples = 20000
    n_data = 15
    n_latent = 5  # I make an assumption on the right latent space size, so this is kept fixed
    cwd = os.getcwd()
    if not (os.path.isfile(os.path.join(cwd, 'train_ds.pt')) and os.path.isfile(os.path.join(cwd, 'val_ds.pt'))):
        prepare_datasets(cwd, n_data, n_samples)
    mdvae = ray_train(param_space, cwd=cwd, n_data=n_data)
    train_ds = torch.load(os.path.join(cwd, 'train_ds.pt'))
    val_ds = torch.load(os.path.join(cwd, 'val_ds.pt'))
    x_val, y_val = val_ds[:]
    y_hat, mean, log_var = mdvae.forward(x_val)
    print(y_val[:3])
    print(y_hat[:3])
    print(torch.std(y_hat, 0))
The output of this is:
3.098008632659912
1.006637454032898
1.000028371810913
0.9986492991447449
0.9981059432029724
tensor([[ 1.8516, -0.3478, 2.6511, 1.9317, -0.1876, 0.4041, 1.6015, 0.3717,
0.8461, -1.1814, -1.4392, -0.9345, -1.1823, 0.7851, 0.2900],
[-0.6445, -0.0632, -1.0273, 0.5480, 1.1906, 0.5398, 0.0707, 0.2530,
0.5554, 2.5385, 0.0644, -0.0644, -1.5081, 0.3629, -1.9816],
[-0.8815, 0.1168, -0.9882, 0.7145, -0.0372, 1.1363, 0.3081, -0.6322,
0.8425, 0.1894, -0.4679, 0.3413, 0.5822, 1.1235, -0.9265]])
tensor([[-0.0023, -0.0152, 0.0029, 0.0012, -0.0078, -0.0040, -0.0068, 0.0090,
-0.0128, -0.0080, 0.0216, 0.0035, 0.0050, -0.0190, -0.0067],
[-0.0005, 0.0043, -0.0106, 0.0028, -0.0019, -0.0043, -0.0063, -0.0039,
-0.0142, 0.0004, 0.0255, -0.0042, -0.0013, -0.0173, -0.0056],
[-0.0026, 0.0002, -0.0004, -0.0010, 0.0057, -0.0078, 0.0026, -0.0010,
-0.0060, -0.0050, 0.0206, 0.0029, 0.0008, -0.0187, -0.0017]])
tensor([0.0024, 0.0102, 0.0073, 0.0030, 0.0120, 0.0046, 0.0099, 0.0062, 0.0089,
0.0048, 0.0042, 0.0051, 0.0034, 0.0009, 0.0072])
My observations
The floats at the top are the per-epoch validation losses; the loss clearly decreases and stabilizes around 1, so I conclude that we are 'learning' something. Looking at the printed tensors, however, the second (the reconstruction) is not at all like the first (the input), so the reconstruction does not work at all. Note also that the reconstruction tensor contains almost exclusively values close to 0; the standard deviation of these reconstructed values is shown in the third tensor. I interpret these problems as posterior collapse.
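To pin this down, here is a minimal diagnostic sketch (assuming mdvae, x_val and y_val from the script above) that splits the validation loss into its two terms; if the posterior has collapsed, I would expect the KL term to be near 0 and the posterior moments to match the standard normal prior:

with torch.no_grad():
    x_val_hat, mean, log_var = mdvae.forward(x_val)
    reconstruct_val = nn.MSELoss()(x_val_hat, y_val)
    KL_div_val = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    # posterior collapse signature: KL ~ 0, posterior mean ~ 0, posterior variance ~ 1
    print('reconstruction term:', float(reconstruct_val))
    print('KL term:', float(KL_div_val))
    print('posterior mean (avg over batch):', mean.mean(0))
    print('posterior var (avg over batch):', log_var.exp().mean(0))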
For reference
The non-standard hyper-parameter meanings are the following:
- n_hidden: fraction of the input space dimension that is used as the dimension of the hidden layers
- n_latent: latent space dimension
- n_hidden_layers: number of hidden layers; the thinking is that maybe the network is not complex enough
- KL_factor: a factor to reduce the weight of the KL divergence in the loss; this is known as a $\beta$-VAE (written out below)
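Written out, the training objective in ray_train is the usual $\beta$-weighted combination of an MSE reconstruction term and the closed-form KL divergence between the diagonal Gaussian posterior and the standard normal prior (in the code the KL sum runs over the batch as well as the latent dimensions):

$$\mathcal{L} = \mathrm{MSE}(\hat{x}, x) + \beta \, D_{\mathrm{KL}}, \qquad D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right),$$

where $\beta$ is KL_factor, and $\mu$ and $\log \sigma^2$ are the outputs of mean_lay and log_var_lay.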
The data in this example are drawn from a correlated multivariate normal distribution (see prepare_datasets); my real data is far more complex and high-dimensional.
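In matrix form, the construction in prepare_datasets is: draw a random correlation matrix $C$ with a prescribed eigenvalue spectrum, take its Cholesky factor $L$ (so $C = L L^\top$), and set

$$\varepsilon \sim \mathcal{N}(0, I_{15}), \qquad x = \varepsilon L, \qquad \tilde{x}_j = \frac{x_j - \bar{x}_j}{s_j},$$

where $\bar{x}_j$ and $s_j$ are the per-column mean and standard deviation used for normalization.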
Previous tries
Things I tried that led to the same issue of having reconstructions that are vectors of 0s and a loss of ~ 1:
- I normalize my real data by subtracting the mean and dividing by the standard deviation; this is also done in the prepare_datasets function in this example
- I tried having n_data = n_hidden = n_latent, meaning that there is no bottleneck; this also led to posterior collapse
- I ran a ray.tune search over the following parameter space:
param_space = {
    'batch_size': tune.choice(list(2 ** np.arange(3, 9))),
    'n_hidden': tune.uniform(min_n_hidden, max_n_hidden),
    'n_latent': n_latent,
    'n_hidden_layers': tune.randint(1, 4),
    'learning_rate': tune.loguniform(1e-5, 3e-1),
    'KL_factor': tune.uniform(1e-5, 1),
}
None of the combinations in this space led to a good reconstruction; all outputs have been vectors of numbers close to 0.
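For completeness, roughly how I wire this space into ray.tune: a sketch assuming Ray's Tuner API, with min_n_hidden, max_n_hidden and num_samples as placeholder values (the tuned variant of ray_train reports val_loss to Ray instead of printing it):

from ray import tune

min_n_hidden, max_n_hidden = 0.3, 1.0  # placeholder bounds
# param_space as above, with n_latent = 5 kept fixed
tuner = tune.Tuner(
    tune.with_parameters(ray_train, cwd=os.getcwd(), n_data=15),
    param_space=param_space,
    tune_config=tune.TuneConfig(metric='val_loss', mode='min', num_samples=100),
)
results = tuner.fit()
print(results.get_best_result().config)  # best hyper-parameter configuration found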
Questions
I am now at a loss; to me it seems like the VAE code is fine, and yet none of the things I have read about posterior collapse have led to a working VAE, even on example data.
- Why do I keep getting 0 vectors?
- If the network is learning (loss going down and stabilizing around 1), why am I not getting a working VAE?
- What should be my next steps in creating a working VAE?
- The VAE is attractive since the latent space samples resemble a standard multivariate normal (MVN). For my project, these latent variables serve as independent variables in a conditional density estimation task, which is made easier if they are MVN. Are there perhaps other algorithms (besides the VAE) that I should look at to learn latent variable representations of my data?
- Are there other encoder/decoder architectures that make sense for continuous tabular data? I currently use only fully connected layers, but maybe something fancier would do the trick?