I am trying to train a variational auto-encoder where x ≈ f_VAE(x) = x_hat. In my real problem, I have 100-400 dimensional data that I would like to compress to around 30 latent variables. For now I have a toy problem with 14-dimensional data that I would like to compress to 7 dimensions. All dimensions are real-valued and have been standardized (zero mean, unit standard deviation). To give some idea of the data, here is a scatter plot:

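For reference, here is a minimal sketch of the kind of model I am training on the toy problem (PyTorch here just for concreteness; the layer sizes are illustrative, the exact ones vary across the hyperparameter search):

```python
import torch
import torch.nn as nn

class ToyVAE(nn.Module):
    """Minimal VAE for the 14-dim -> 7-dim toy problem (illustrative sizes)."""
    def __init__(self, x_dim=14, z_dim=7, h_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.fc_mu = nn.Linear(h_dim, z_dim)      # posterior mean
        self.fc_logvar = nn.Linear(h_dim, z_dim)  # posterior log-variance
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # reparameterization trick: z = mu + sigma * eps
        eps = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * logvar) * eps
        x_hat = self.decoder(z)
        return x_hat, mu, logvar
```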
I train the VAE by minimizing the loss loss = MSE + beta * KL_div, where MSE is the mean-squared reconstruction error, KL_div is the KL-divergence term and beta is a scaling factor. I keep running into the posterior collapse problem. In brief: the MSE stays at ≈ 1.0 and does not drop any lower no matter how long I train or which model architecture I use, the KL_div vanishes to ≈ 0.0, and the reconstruction x_hat is a vector of values around 0 (the mean of the standardized data) regardless of the input x / latent variables z. Posterior collapse has been described in the literature and several solutions have been proposed, such as beta-VAE [2] and cyclical annealing VAE [3]. For my problem, I have tried the following (a code sketch of the loss and the beta schedules follows the list):
- normal VAE (beta = 1)
- beta VAE (beta < 1)
- linear annealing with warmup (beta = 0 for 4 epochs, then rises linearly to 1 over the remaining 16 epochs, incremented every mini-batch)
- cyclical annealing VAE with warmup (beta = 0 for 4 epochs, then 4 cycles between 0 and 1, incremented every mini-batch)
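Concretely, the loss and the two beta schedules look roughly like this (a sketch, assuming 20 epochs and a fixed number of mini-batches per epoch; the function names are just mine):

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_hat, mu, logvar, beta):
    """loss = MSE + beta * KL_div, with the KL taken against a standard normal prior."""
    mse = F.mse_loss(x_hat, x, reduction="mean")
    # KL(q(z|x) || N(0, I)), averaged over batch and latent dimensions
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kl, mse, kl

def beta_linear_warmup(step, steps_per_epoch, warmup_epochs=4, total_epochs=20):
    """beta = 0 during warmup, then rises linearly to 1 over the remaining epochs."""
    warmup = warmup_epochs * steps_per_epoch
    ramp = (total_epochs - warmup_epochs) * steps_per_epoch
    return 0.0 if step < warmup else min(1.0, (step - warmup) / ramp)

def beta_cyclical(step, steps_per_epoch, warmup_epochs=4, total_epochs=20, n_cycles=4):
    """beta = 0 during warmup, then n_cycles ramps from 0 to 1
    (ramp for half of each cycle, hold at 1 for the rest), updated every mini-batch."""
    warmup = warmup_epochs * steps_per_epoch
    if step < warmup:
        return 0.0
    cycle_len = (total_epochs - warmup_epochs) * steps_per_epoch // n_cycles
    return min(1.0, ((step - warmup) % cycle_len) / (0.5 * cycle_len))
```

In the training loop, beta is looked up from the schedule at every mini-batch (step = epoch * steps_per_epoch + batch_idx) and passed to vae_loss; the plain AE run mentioned below is simply the same loop with beta fixed at 0.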
At the end of the day, none of these strategies work. During warmup, when beta = 0, the MSE goes down to around 0.2, actually yielding decent reconstructions x_hat. As soon as the KL regularization kicks in, the network stops caring about the MSE reconstruction error and the KL_div vanishes. Using Ray Tune, I tried many different VAE architectures with varying numbers of hidden layers, different sizes of the latent and hidden layers, and other learning parameters. One interesting observation is that training a plain auto-encoder (AE, i.e. beta = 0 for all 20 epochs) does work and gives decent reconstructions. The AE even performs reasonably well on the validation data.
Questions in decreasing order of importance:
- What else can I try to make a VAE work for my problem?
- Why is it possible to train an AE on this data, but as soon as some KL-divergence regularization is introduced, the whole task becomes impossible?
- What could possible explanations be for the KL regularization leading to posterior collapse?
- What other machine learning approaches for latent variable learning could I look into?