
I am trying to train a variational auto-encoder (VAE) where x ≈ f_VAE(x) = x_hat. In my real problem, I have 100-400 dimensional data that I would like to compress to around 30 latent variables. For now I have a toy problem with 14-dimensional data that I would like to compress to 7 dimensions. All dimensions are real numbers and have been standardized (zero mean, unit standard deviation). To give some idea of the data, here is a scatter plot: (scatter plot image)

I train the VAE by minimizing the loss loss = MSE + beta * KL_div, where MSE is the mean squared error, KL_div is the KL-divergence term, and beta is a scaling factor. I keep running into the posterior collapse problem. In brief, MSE ≈ 1.0 and does not drop lower no matter how long I train or which model architecture I use. KL_div ≈ 0.0, i.e. it vanishes, and the reconstruction x_hat is a vector of values around 0 (the mean of the standardized data) regardless of the input x or the latent variables z. Posterior collapse has been described in the literature, and there are several proposed solutions such as beta-VAE [2] and cyclical annealing VAE [3]. For my problem, I have tried the following:
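To make the loss and the collapsed state concrete, here is a minimal pure-Python sketch. It is not my actual training code: it assumes a diagonal Gaussian posterior with a standard normal prior, an MSE averaged over all samples and features, and a KL term summed over latent dimensions and averaged over the batch. The function names `mse`, `kl_div` and `vae_loss` are made up for illustration.

```python
import math

def mse(x, x_hat):
    """Mean squared error, averaged over all samples and features (assumed reduction)."""
    n = sum(len(row) for row in x)
    return sum((a - b) ** 2
               for xr, hr in zip(x, x_hat)
               for a, b in zip(xr, hr)) / n

def kl_div(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)).

    Summed over latent dimensions, averaged over the batch (assumed reduction).
    """
    per_sample = [
        0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mr, lr))
        for mr, lr in zip(mu, logvar)
    ]
    return sum(per_sample) / len(per_sample)

def vae_loss(x, x_hat, mu, logvar, beta):
    """loss = MSE + beta * KL_div, as in the text above."""
    return mse(x, x_hat) + beta * kl_div(mu, logvar)

# Toy standardized batch: each feature has zero mean and unit (population) std.
x = [[-1.0, 1.0], [1.0, -1.0]]

# Collapsed model: the decoder outputs the data mean (0) and the posterior equals the prior.
x_hat = [[0.0, 0.0], [0.0, 0.0]]
mu = [[0.0], [0.0]]
logvar = [[0.0], [0.0]]

print(mse(x, x_hat))                          # MSE equals the data variance
print(kl_div(mu, logvar))                     # KL is exactly zero
print(vae_loss(x, x_hat, mu, logvar, 1.0))
```

With standardized data, always predicting the mean gives an MSE equal to the feature variance, which matches the MSE ≈ 1.0 and KL_div ≈ 0.0 plateau described above.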

  • normal VAE (beta = 1)
  • beta VAE (beta < 1)
  • linear annealing with warmup (beta = 0 for 4 epochs, then rises linearly to 1 over the remaining 16 epochs, incremented every mini-batch)
  • cyclical annealing VAE with warmup (beta = 0 for 4 epochs, then goes through 4 cycles between 0 and 1, incremented every mini-batch)
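The two annealing schedules in the list above can be sketched roughly as follows. This is an illustration rather than my exact code: the function names, the `steps_per_epoch` argument and the `ramp_fraction` parameter (the share of each cycle spent ramping up before beta is held at 1) are assumptions, and the cyclical variant is assumed to use the same 4 + 16 epoch split as the linear one.

```python
def linear_beta(step, steps_per_epoch, warmup_epochs=4, ramp_epochs=16):
    """beta = 0 during warmup, then a linear ramp to 1, updated every mini-batch."""
    warmup = warmup_epochs * steps_per_epoch
    ramp = ramp_epochs * steps_per_epoch
    if step < warmup:
        return 0.0
    return min(1.0, (step - warmup) / ramp)

def cyclical_beta(step, steps_per_epoch, warmup_epochs=4, ramp_epochs=16,
                  n_cycles=4, ramp_fraction=0.5):
    """beta = 0 during warmup, then n_cycles cycles that each ramp from 0 to 1.

    Within a cycle, beta rises linearly over the first `ramp_fraction` of the
    cycle and stays at 1 for the rest (ramp_fraction is an assumed parameter).
    """
    warmup = warmup_epochs * steps_per_epoch
    if step < warmup:
        return 0.0
    cycle_len = ramp_epochs * steps_per_epoch / n_cycles
    position = ((step - warmup) % cycle_len) / cycle_len  # in [0, 1)
    return min(1.0, position / ramp_fraction)

# Example with 10 mini-batches per epoch: print beta at the start of each epoch.
for epoch in range(20):
    step = epoch * 10
    print(epoch, linear_beta(step, 10), cyclical_beta(step, 10))
```

`step` here is a global mini-batch counter starting at 0, so beta changes with every mini-batch rather than once per epoch.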

None of these strategies work. During warmup, when beta = 0, the MSE goes down to around 0.2, yielding decent reconstructions x_hat. As soon as the KL regularization kicks in, the network stops caring about the MSE reconstruction error and KL_div vanishes. Using Ray Tune, I tried many different VAE architectures with different numbers of hidden layers, different sizes of latent and hidden layers, and different learning parameters. One interesting observation is that training a plain auto-encoder (AE, i.e. beta = 0 for all 20 epochs) does work and gives decent reconstructions. The AE even performs reasonably well on the validation data.

Questions in decreasing order of importance:

  • What else can I try to make VAEs work for my problem?
  • Why is it possible to train an AE on this data, while the task becomes impossible as soon as some KL-divergence regularization is introduced?
  • What are possible explanations for the KL regularization leading to posterior collapse?
  • What other machine learning approaches to latent variable learning could I look into?
Patrickens