
I know I can make a VAE generate samples using a prior with a mean of 0 and a standard deviation of 1. I tested this with the following loss function:

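```python
import torch

def loss(self, data, reconst, mu, sig):
    rl = self.reconLoss(reconst, data)  # MSE reconstruction loss
    #dl = self.divergenceLoss(mu, sig)
    # sig is used as a log-variance, so std = exp(0.5 * sig)
    std = torch.exp(0.5 * sig)
    # Target distribution N(0, 1), one entry per latent element
    compMeans = torch.full(std.size(), 0.0)
    compStd = torch.full(std.size(), 1.0)
    dl = kld(mu, std, compMeans, compStd)
    totalLoss = self.rw * rl + self.dw * dl
    return (totalLoss, rl, dl)

def kld(mu1, std1, mu2, std2):
    # Elementwise KL(N(mu1, std1^2) || N(mu2, std2^2)), averaged over all entries
    p = torch.distributions.Normal(mu1, std1)
    q = torch.distributions.Normal(mu2, std2)
    return torch.distributions.kl_divergence(p, q).mean()
```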

Here, mu and sig are the encoder outputs for the latent vector (sig is used as a log-variance, hence std = torch.exp(0.5 * sig)), and reconLoss is MSE. This works well: I can generate MNIST digits by feeding in noise from a standard normal distribution.
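For reference, the KL divergence between two univariate normals has the closed form log(s2 / s1) + (s1² + (m1 − m2)²) / (2 s2²) − 1/2, and `torch.distributions.kl_divergence` evaluates it elementwise for a pair of `Normal` distributions. The sketch below recomputes it by hand with a hypothetical helper `kld_closed_form` and compares it with the `kld` function above for a prior with mean 10; the inputs are arbitrary random values, not data from this model.

```python
import torch

def kld_closed_form(mu1, std1, mu2, std2):
    # Hypothetical helper: elementwise KL( N(mu1, std1^2) || N(mu2, std2^2) )
    kl = (torch.log(std2 / std1)
          + (std1 ** 2 + (mu1 - mu2) ** 2) / (2 * std2 ** 2)
          - 0.5)
    return kl.mean()  # same reduction as kld() in the question

def kld(mu1, std1, mu2, std2):
    # Same as the question's kld(): torch's built-in Normal/Normal KL, averaged
    p = torch.distributions.Normal(mu1, std1)
    q = torch.distributions.Normal(mu2, std2)
    return torch.distributions.kl_divergence(p, q).mean()

torch.manual_seed(0)
mu = torch.randn(16, 10) + 10.0        # arbitrary posterior means near 10
std = torch.rand(16, 10) + 0.5         # arbitrary positive std-devs
prior_mu = torch.full_like(mu, 10.0)   # prior N(10, 1)
prior_std = torch.ones_like(std)
print(kld_closed_form(mu, std, prior_mu, prior_std).item(),
      kld(mu, std, prior_mu, prior_std).item())
```

Both numbers should agree up to floating-point error, since they evaluate the same formula.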

However, I'd now like the latent distribution to be concentrated around a normal distribution with a mean of 10 and a standard deviation of 1. I tried changing the loss like this:

```python
compMeans = torch.full(std.size(), 10.0)
```

I made the same change in the reparameterization and generation functions. But what worked for the standard normal distribution does not work for the mean-10 one: reconstruction still works fine, but generation only produces strange shapes. Oddly, the divergence loss still goes down, reaching a level similar to the one it reached with the standard normal.

Does anyone know why this isn't working? Is there something about the KL divergence that does not work with non-standard normal distributions?
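One property that bears on the second question: the KL divergence between two normals depends on the means only through their difference, so KL(N(μ, σ²) || N(10, 1)) equals KL(N(μ − 10, σ²) || N(0, 1)). A quick numerical check of that identity with `torch.distributions`, using arbitrary values (this only illustrates the KL term, not the rest of the training setup):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 10) * 3 + 10   # arbitrary posterior means
std = torch.rand(8, 10) + 0.1      # arbitrary posterior std-devs (> 0)

# KL against the shifted prior N(10, 1)
kl_shifted = kl_divergence(Normal(mu, std),
                           Normal(torch.full_like(mu, 10.0), torch.ones_like(std)))
# KL against the standard prior N(0, 1) after shifting the posterior means by -10
kl_standard = kl_divergence(Normal(mu - 10.0, std),
                            Normal(torch.zeros_like(mu), torch.ones_like(std)))

print(torch.allclose(kl_shifted, kl_standard, atol=1e-5))
```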

Other things I've tried:

  • Generating from (0, 1) after training on (10, 1): failed
  • Generating from (-10, 1) after training on (10, 1): failed
  • A custom version of the KL divergence: worked on (0, 1), failed on (10, 1)
  • Using sig directly as the std instead of std = torch.exp(0.5 * sig): failed
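The hand-written KL term commonly used in VAEs, −0.5 · (1 + logvar − μ² − exp(logvar)), assumes a N(0, 1) prior. A sketch generalizing it to a N(m, 1) prior, under the assumption that `sig` is a log-variance as in `std = torch.exp(0.5 * sig)` above (`kld_logvar` and `prior_mean` are hypothetical names, not from the original code), checked against `torch.distributions`:

```python
import torch
from torch.distributions import Normal, kl_divergence

def kld_logvar(mu, logvar, prior_mean=0.0):
    # Hypothetical helper: elementwise KL( N(mu, exp(logvar)) || N(prior_mean, 1) ),
    # averaged over all entries. With prior_mean=0 this is the usual
    # -0.5 * (1 + logvar - mu^2 - exp(logvar)).
    kl = 0.5 * (torch.exp(logvar) + (mu - prior_mean) ** 2 - 1.0 - logvar)
    return kl.mean()

torch.manual_seed(0)
mu = torch.randn(16, 10) + 10.0        # arbitrary means near 10
logvar = torch.randn(16, 10) * 0.5     # arbitrary log-variances
std = torch.exp(0.5 * logvar)
reference = kl_divergence(Normal(mu, std),
                          Normal(torch.full_like(mu, 10.0), torch.ones_like(std))).mean()
print(kld_logvar(mu, logvar, prior_mean=10.0).item(), reference.item())
```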

Edit 1: Below are my loss plots for the (0, 1) distribution.

Reconstruction loss: [plot]

Divergence loss: [plot]

Generation samples: [image]

Reconstruction samples (left is input, right is output): [image]

And here are the plots for the (10, 1) normal distribution.

Reconstruction loss: [plot]

Divergence loss: [plot]

Generation samples: [image]

Note: when I ran it this time, it did seem to learn generation a little, though it still mostly produces 8s or shapes that are structurally close to an 8. This does not happen with the standard normal distribution. The only difference from the previous run is the random seed.

Reconstruction samples: [image]

Sampled latent:

```
tensor([[ 9.6411,  9.9796,  9.9829, 10.0024,  9.6115,  9.9056,  9.9095, 10.0684,
         10.0435,  9.9308],
        [ 9.8364, 10.0890,  9.8836, 10.0544,  9.4017, 10.0457, 10.0134,  9.9539,
         10.0986, 10.0434],
        [ 9.9301,  9.9534, 10.0042, 10.1110,  9.8654,  9.4630, 10.0256,  9.9237,
          9.8614,  9.7408],
        [ 9.3332, 10.1289, 10.0212,  9.7660,  9.7731,  9.9771,  9.8550, 10.0152,
          9.9879, 10.1816],
        [10.0605,  9.8872, 10.0057,  9.6858,  9.9998,  9.4429,  9.8378, 10.0389,
          9.9264,  9.8789],
        [10.0931,  9.9347, 10.0870,  9.9941, 10.0001, 10.1102,  9.8260, 10.1521,
          9.9961, 10.0989],
        [ 9.5413,  9.8965,  9.2484,  9.7604,  9.9095,  9.8409,  9.3402,  9.8552,
          9.7309,  9.7300],
        [10.0113,  9.5318,  9.9867,  9.6139,  9.9422, 10.1269,  9.9375,  9.9242,
          9.9532,  9.9053],
        [ 9.8866, 10.1696,  9.9437, 10.0858,  9.5781, 10.1011,  9.8957,  9.9684,
          9.9904,  9.9017],
        [ 9.6977, 10.0545, 10.0383,  9.9647,  9.9738,  9.9795,  9.9165, 10.0705,
          9.9072,  9.9659],
        [ 9.6819, 10.0224, 10.0547,  9.9457,  9.9592,  9.9380,  9.8731, 10.0825,
          9.8949, 10.0187],
        [ 9.6339,  9.9985,  9.7757,  9.4039,  9.7309,  9.8588,  9.7938,  9.8712,
          9.9763, 10.0186],
        [ 9.7688, 10.0575, 10.0515, 10.0153,  9.9782, 10.0115,  9.9269, 10.1228,
          9.9738, 10.0615],
        [ 9.8575,  9.8241,  9.9603, 10.0220,  9.9342,  9.9557, 10.1162, 10.0428,
         10.1363, 10.3070],
        [ 9.6856,  9.7924,  9.9174,  9.5064,  9.8072,  9.7176,  9.7449,  9.7004,
          9.8268,  9.9878],
        [ 9.8630, 10.0470, 10.0227,  9.7871, 10.0410,  9.9470, 10.0638, 10.1259,
         10.1669, 10.1097]])
```

Note that this does seem to be drawn from the right distribution.
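One way to make "the right distribution" concrete is to compare the per-dimension sample mean and standard deviation of a batch of latents with the target (10, 1). A minimal sketch, where `latent_stats` is a hypothetical helper and the latents are synthetic draws from N(10, 1), not the tensor above:

```python
import torch

def latent_stats(z):
    # Hypothetical helper: per-dimension mean and (unbiased) std of a [batch, dim] tensor
    return z.mean(dim=0), z.std(dim=0)

torch.manual_seed(0)
z = torch.randn(10_000, 10) + 10.0   # synthetic latents drawn from N(10, 1)
means, stds = latent_stats(z)
print(means)
print(stds)
```

For samples that really follow N(10, 1), both vectors should be close to 10 and 1 respectively; a mean near 10 combined with a per-dimension std well below 1 would indicate a narrower distribution than the target.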

Just in case, here is my reparameterization method too, currently set up for the (10, 1) distribution:

```python
def reparamaterize(self, mu, sig):
    std = torch.exp(0.5 * sig)  # sig is treated as the log-variance
    # Noise eps is drawn from N(10, 1) instead of N(0, 1)
    epsMeans = torch.full(std.size(), 10.0)
    epsStd = torch.full(std.size(), 1.0)
    eps = torch.normal(epsMeans, epsStd)
    return eps * std + mu
```
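Since this method returns z = eps · std + mu, the mean of z is mu + std · E[eps], so the noise mean is scaled by std and added to mu. A quick empirical check of that arithmetic, with mu fixed at 10 and std at 1 (values chosen only for illustration); `reparam_shifted_eps` mirrors the method above as a standalone function, and `reparam_standard_eps` is the conventional variant with eps ~ N(0, 1). Both names are hypothetical:

```python
import torch

def reparam_shifted_eps(mu, logvar):
    # Same arithmetic as reparamaterize() above: eps ~ N(10, 1)
    std = torch.exp(0.5 * logvar)
    eps = torch.normal(torch.full_like(std, 10.0), torch.ones_like(std))
    return eps * std + mu

def reparam_standard_eps(mu, logvar):
    # Conventional reparameterization: eps ~ N(0, 1), so z ~ N(mu, std^2)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps * std + mu

torch.manual_seed(0)
n = 100_000
mu = torch.full((n,), 10.0)   # illustrative posterior mean
logvar = torch.zeros(n)       # log-variance 0, i.e. std = 1
z_shifted = reparam_shifted_eps(mu, logvar)
z_standard = reparam_standard_eps(mu, logvar)
# E[z] = mu + std * E[eps]: about 20 for the shifted noise, about 10 for standard noise
print(z_shifted.mean().item(), z_standard.mean().item())
```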