2

I am trying to find a low-dimensional latent space representation for a bunch of simulated data. No matter what VAE architecture I try and no matter how I tweak it, the output of the VAE is always the same, where all values hover around 0; it does not reconstruct the input at all. As far as I can understand, the problem is known as posterior collapse in the literature.

I use ray.tune to find hyper-parameters that work and the lowest validation loss I ever get is val_loss ~ 1. I picked one hyper parameter configuration that yields a val_loss ~ 1 and used it to train a VAE in the script below. The data in this example is 15 dimensional and drawn from a correlated multivariate normal distribution.

import math
import os
import numpy as np
from scipy.stats import random_correlation, loguniform

import torch from torch.utils.data import Dataset, DataLoader, random_split from torch import nn

import time

class MDVAE(nn.Module): def init( self, n_data: int, n_latent = 0.3, n_hidden = 0.6, n_hidden_layers = 1, bias = True, # activation=nn.LeakyReLU(0.01), activation=nn.LeakyReLU(0.01), ): super().init() if isinstance(n_latent, float): n_latent = math.ceil(n_latent * n_data) if isinstance(n_hidden, float): n_hidden = math.ceil(n_hidden * n_data)

    if (n_latent > n_hidden) or (n_latent > n_data):
        raise ValueError

    endocer_layers = [nn.Linear(n_data, n_hidden, bias=bias), activation]
    for i in range(n_hidden_layers):
        endocer_layers.extend([nn.Linear(n_hidden, n_hidden, bias=bias), activation])

    self.encoder = nn.Sequential(*endocer_layers)
    self.mean_lay = nn.Linear(n_hidden, n_latent)
    self.log_var_lay  = nn.Linear(n_hidden, n_latent)

    decoder_layers = []
    for layer in endocer_layers:
        if layer == activation:
            continue
        decoder_layers.extend([nn.Linear(layer.out_features, layer.in_features), activation])
    decoder_layers.extend([nn.Linear(n_latent, n_hidden)])
    self.decoder = nn.Sequential(*decoder_layers[::-1])

def encode_and_sample(self, x):
    hidden = self.encoder(x)
    mean = self.mean_lay(hidden)
    log_var = self.log_var_lay(hidden)

    # reparametrization trick
    std = torch.exp(0.5 * log_var)
    epsilon = torch.randn_like(std)
    z = mean + std * epsilon
    return z, mean, log_var

def forward(self, x):
    z, mean, log_var = self.encode_and_sample(x)
    return self.decoder(z), mean, log_var

class VAE_Dataset(Dataset): def init(self, data: torch.Tensor): self.data = data

def __len__(self):
    return self.data.shape[0]

def __getitem__(self, idx):
    return self.data[idx], self.data[idx]


_PRR = lambda x: float(x.to('cpu').data.numpy()) def ray_train(config, cwd, n_data, n_epoch=5, show_progress=False): n_hidden = float(config.get('n_hidden', 0.6)) n_latent = int(config.get('n_latent', 0.3)) n_hidden_layers = int(config.get('n_hidden_layers', 2)) batch_size = int(config.get('batch_size', 64)) learning_rate = float(config.get('learning_rate', 1e-3)) weight_decay = float(config.get('weight_decay', 1e-4)) KL_factor = float(config.get('KL_factor', 1.0))

train_ds = torch.load(os.path.join(cwd, 'train_ds.pt'))
val_ds = torch.load(os.path.join(cwd, 'val_ds.pt'))
x_val, y_val = val_ds[:]
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

mdvae = MDVAE(
    n_data=n_data,
    n_hidden=n_hidden,
    n_latent=n_latent,
    n_hidden_layers=n_hidden_layers,
)

loss_f = nn.MSELoss()
optimizer = torch.optim.Adam(mdvae.parameters(), lr=learning_rate, weight_decay=weight_decay)

if show_progress:
    raise NotImplementedError

for epoch in range(n_epoch):
    for i, (x, y) in enumerate(train_loader):
        x_hat, mean, log_var = mdvae.forward(x)
        reconstruct = loss_f(x_hat, y)
        KL_div = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        loss = reconstruct + KL_factor * KL_div
        optimizer.zero_grad()
        if ~(torch.isnan(loss) | torch.isinf(loss)):
            loss.backward()
            optimizer.step()

    with torch.no_grad():
        x_val_hat, mean, log_var = mdvae.forward(x_val)
        reconstruct_val = loss_f(x_val_hat, y_val)
        KL_div_val = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        val_loss = reconstruct_val + KL_factor * KL_div_val

        print(_PRR(val_loss))
return mdvae


def prepare_datasets( cwd, n_data, n_samples=20000, val_frac=0.1 ): # reproduciple random numbers rng = torch.Generator() rng.manual_seed(3) np_rng = np.random.RandomState(3)

# create a data-set with correlated variables
eigs = np_rng.rand(n_data)
eigs = eigs * (n_data / sum(eigs))
corr = torch.from_numpy(
    random_correlation.rvs(eigs, random_state=np_rng)
)
chol = torch.linalg.cholesky(corr)

data = torch.randn((n_samples, n_data), generator=rng, dtype=torch.float64)
data = data @ chol  # correlated data
data = (data - torch.mean(data, 0)) / torch.std(data, 0)  # normalized data
dataset = VAE_Dataset(data.float())

n_validate = math.ceil(val_frac * n_samples)

# prepare datasets
train_ds, val_ds = random_split(
    dataset,
    lengths=(len(dataset) - n_validate, n_validate),
)
torch.save(train_ds, os.path.join(cwd, 'train_ds.pt'))
torch.save(val_ds, os.path.join(cwd, 'val_ds.pt'))


if name == "main": param_space = { 'batch_size': 32, 'n_hidden': .8, 'n_latent': 10, 'n_hidden_layers': 4, 'learning_rate': 0.0001, 'KL_factor': 0.1, }

n_samples = 20000
n_data = 15
n_latent = 5  # I make an assumption on the right latent space size, so this is kept fixed

cwd = os.getcwd()
if not (os.path.isfile(os.path.join(cwd, 'train_ds.pt')) and os.path.isfile(os.path.join(cwd, 'val_ds.pt'))):
    prepare_datasets(cwd, n_data, n_samples)

mdvae = ray_train(param_space, cwd=os.getcwd(), n_data=15)

cwd = os.getcwd()
train_ds = torch.load(os.path.join(cwd, 'train_ds.pt'))
val_ds = torch.load(os.path.join(cwd, 'val_ds.pt'))
x_val, y_val = val_ds[:]

y_hat, mean, log_var = mdvae.forward(x_val)

print(y_val[:3])
print(y_hat[:3])
print(torch.std(y_hat, 0))

THe output of this is:

3.098008632659912
1.006637454032898
1.000028371810913
0.9986492991447449
0.9981059432029724
tensor([[ 1.8516, -0.3478,  2.6511,  1.9317, -0.1876,  0.4041,  1.6015,  0.3717,
          0.8461, -1.1814, -1.4392, -0.9345, -1.1823,  0.7851,  0.2900],
        [-0.6445, -0.0632, -1.0273,  0.5480,  1.1906,  0.5398,  0.0707,  0.2530,
          0.5554,  2.5385,  0.0644, -0.0644, -1.5081,  0.3629, -1.9816],
        [-0.8815,  0.1168, -0.9882,  0.7145, -0.0372,  1.1363,  0.3081, -0.6322,
          0.8425,  0.1894, -0.4679,  0.3413,  0.5822,  1.1235, -0.9265]])
tensor([[-0.0023, -0.0152,  0.0029,  0.0012, -0.0078, -0.0040, -0.0068,  0.0090,
         -0.0128, -0.0080,  0.0216,  0.0035,  0.0050, -0.0190, -0.0067],
        [-0.0005,  0.0043, -0.0106,  0.0028, -0.0019, -0.0043, -0.0063, -0.0039,
         -0.0142,  0.0004,  0.0255, -0.0042, -0.0013, -0.0173, -0.0056],
        [-0.0026,  0.0002, -0.0004, -0.0010,  0.0057, -0.0078,  0.0026, -0.0010,
         -0.0060, -0.0050,  0.0206,  0.0029,  0.0008, -0.0187, -0.0017]])
tensor([0.0024, 0.0102, 0.0073, 0.0030, 0.0120, 0.0046, 0.0099, 0.0062, 0.0089,
        0.0048, 0.0042, 0.0051, 0.0034, 0.0009, 0.0072])

My observations

The floats at the top are the validation loss and clearly it goes down and stabilizes around 1. From this I conclude that we are 'learning' something. If we then look at printed tensors, the first one is not at all like the second one and so the reconstruction does not work at all. Also note that the reconstruction tensor almost only contains values close to 0. The standard deviation of these reconstructed values is shown in the third tensor. I interpret these problems as posterior collapse.

For reference

the non-standard hyper parameter meanings are the following:

  • n_hidden: fraction of the input space dimensions that is used as the dimension of the hidden layers
  • n_latent: latent space dimension
  • n_hidden_layers: number of hidden layers; thinking is that maybe the network is not complex enough?
  • KL_factor: a factor to reduce the weight of the KL divergence in the loss; this is known as a $\beta$-VAE

The data in this example are drawn from a correlated multivariate normal distribution (see prepare_datasets), my real data is far more complex and high dimensional.

Previous tries

Things I tried that led to the same issue of having reconstructions that are vectors of 0s and a loss of ~ 1:

  • I normalize my real data by subtracting the mean and dividing by standard deviation; also the case for the prepare_data function in this example
  • I tried having n_data = n_hidden = n_latent meaning that there is no bottleneck; this also lead to posterior collapse
  • ray tune parameter search space:
param_space = {
        'batch_size': tune.choice(list(2 ** np.arange(3, 9))),
        'n_hidden': tune.uniform(min_n_hidden, max_n_hidden),
        'n_latent': n_latent,
        'n_hidden_layers': tune.randint(1, 4),
        'learning_rate': tune.loguniform(1e-5, 3e-1),
        'KL_factor': tune.uniform(1e-5, 1),
    }

None of the combinations in this space lead to a good reconstruction and all outputs have been just vectors with numbers close to 0.

Questions

I am now at a loss; to me it seems like the VAE code is fine and yet none of the things I have read about posterior collapse have lead to a working VAE even on example data.

  • Why do I keep getting 0 vectors?
  • If the network is learning (loss going down and stabilizing around 1), why am I not getting a working VAE?
  • What should be my next steps in creating a working VAE?
  • The VAE is attractive since the latent space samples resemble a standard multivariate normal (MVN). For my project, these latent variables serve as independent variables in a conditional density estimation task and this is made easier if they are MVN. Are there perhaps other algorithms (than VAE) that I should look at to learn latent variable representations of my data?
  • Are there other encoder/decoder architectures that make sense for continuous tabular data? I now just use fully connected layers but maybe something more fancy would do the trick?
desertnaut
  • 1,021
  • 11
  • 19
Patrickens
  • 131
  • 2

0 Answers0