
I am trying to do the reverse of the typical MNIST task in machine learning using a GAN: instead of predicting a number from an image of a digit, I want to generate an image of a digit from a number. A traditional GAN, however, isn't designed for this use case, because it generates images similar to its training data from random noise alone, without any input that controls what it produces. One workaround I've thought of is to take the digit label from the training data, feed it into a densely-connected layer Dense(784), reshape the result to (28 x 28 x 1), and then continue with the generator as one usually does for a GAN. However, this seems like "fooling" the neural network into making up weights out of thin air, and I doubt it would work properly.

How can I modify a GAN so that it takes single-digit inputs without resorting to the aforementioned approach?

Neil Slater
JS4137

2 Answers


This is a common use case for GANs: you want the output to be conditioned on some controlled input, as opposed to just random seed data.

The Medium article "cGAN: Conditional Generative Adversarial Network — How to Gain Control Over GAN Outputs" walks through almost exactly your example (using MNIST digits and selecting which one you want).

The basic changes you need to make to turn your "freely choose class" GAN into a conditional GAN:

  • Add input to the generator, in addition to the random seed, that describes the class you want to generate. For a digit generator, that might be a one-hot array for the digit class you want it to make.

  • Add the same input to the discriminator, alongside the image input. That may involve having inputs at different layers, so that you can combine CNN and fully-connected layers more easily. Typically you would concatenate the class choice to the flattened last CNN layer, and use this concatenated vector as input to the first fully-connected layer. But you could concatenate the class data to the input of any layer before the output.

  • Train as before, whilst tracking which class is being faked (or real) and ensuring the generator and discriminator are fed the correct class details during training.
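To make the points above concrete, here is a minimal NumPy sketch of the forward pass of a conditional generator and discriminator. It only illustrates where the class input gets concatenated; it is not a working cGAN. Assumptions: the weights are random and untrained, a single dense layer with ReLU stands in for the CNN feature extractor, and the sizes (NOISE_DIM = 100, 128 features) and all function names are hypothetical. A real implementation would use a framework such as Keras or PyTorch and train both networks adversarially.

```python
import numpy as np

rng = np.random.default_rng(0)

NOISE_DIM = 100   # size of the random seed vector (assumed)
NUM_CLASSES = 10  # digits 0-9
IMG_SHAPE = (28, 28, 1)
IMG_SIZE = 28 * 28

def one_hot(labels, num_classes=NUM_CLASSES):
    """Turn integer digit labels into one-hot rows."""
    out = np.zeros((len(labels), num_classes))
    out[np.arange(len(labels)), labels] = 1.0
    return out

# Random, untrained weights; in a real cGAN these would be learned.
W_gen = rng.normal(0, 0.02, size=(NOISE_DIM + NUM_CLASSES, IMG_SIZE))
W_feat = rng.normal(0, 0.02, size=(IMG_SIZE, 128))   # stand-in for the CNN layers
W_disc = rng.normal(0, 0.02, size=(128 + NUM_CLASSES, 1))

def generator(z, labels):
    # Condition the generator: concatenate the random seed with the requested class.
    x = np.concatenate([z, one_hot(labels)], axis=1)
    img = np.tanh(x @ W_gen)                  # pixel values in [-1, 1]
    return img.reshape(-1, *IMG_SHAPE)

def discriminator(images, labels):
    # Flatten the image and extract features (a CNN would go here).
    feats = np.maximum(images.reshape(len(images), -1) @ W_feat, 0.0)
    # Condition the discriminator: append the claimed class to the flattened features.
    h = np.concatenate([feats, one_hot(labels)], axis=1)
    return 1.0 / (1.0 + np.exp(-(h @ W_disc)))   # P("real image of this class")

labels = np.array([3, 7])
z = rng.normal(size=(len(labels), NOISE_DIM))
fake = generator(z, labels)
print(fake.shape)                          # (2, 28, 28, 1)
print(discriminator(fake, labels).shape)   # (2, 1)
```

Note that the label enters both networks: the generator sees it next to the noise, and the discriminator sees it next to the image features, so it can judge whether the image matches the claimed digit.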

You could optionally provide incorrect labels for some real or fake images and score them as fake, but that is not strictly necessary. The discriminator should learn to detect that something declared to be a '1' looks more like a '2' and mark it as fake, without needing to be specifically trained for it.
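If you do add the optional mismatched-label examples, the extra work is mostly bookkeeping when building the discriminator batch. Below is a sketch assuming NumPy arrays of images and integer labels; the function names are hypothetical, and random arrays stand in for MNIST images and generator output.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_CLASSES = 10

def wrong_labels(labels, num_classes=NUM_CLASSES, rng=rng):
    """Return labels guaranteed to differ from the given labels."""
    shift = rng.integers(1, num_classes, size=len(labels))  # 1 .. num_classes-1
    return (labels + shift) % num_classes

def build_disc_batch(real_imgs, real_labels, fake_imgs, fake_labels):
    """Assemble (images, labels, targets) for one discriminator update.

    real image + true label   -> 1 (real)
    fake image + its label    -> 0 (fake)
    real image + wrong label  -> 0 (mismatch, the optional extra case)
    """
    mismatched = wrong_labels(real_labels)
    images = np.concatenate([real_imgs, fake_imgs, real_imgs])
    labels = np.concatenate([real_labels, fake_labels, mismatched])
    targets = np.concatenate([np.ones(len(real_imgs)),
                              np.zeros(len(fake_imgs)),
                              np.zeros(len(real_imgs))])
    return images, labels, targets

real_imgs = rng.normal(size=(4, 28, 28, 1))   # placeholder for real MNIST images
real_labels = np.array([0, 1, 2, 9])
fake_imgs = rng.normal(size=(4, 28, 28, 1))   # placeholder for generator output
fake_labels = np.array([5, 5, 6, 7])
imgs, labs, tgts = build_disc_batch(real_imgs, real_labels, fake_imgs, fake_labels)
print(imgs.shape, labs.shape, tgts.shape)     # (12, 28, 28, 1) (12,) (12,)
```

Adding a random shift of 1 to 9 modulo 10 guarantees that the "wrong" label never accidentally equals the true one.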

Neil Slater

I agree with @Neil's answer; I also strongly believe that cGANs are the right answer to your problem.

However, as he suggested, it may be worth mentioning that GAN inversion can also be used to achieve such results; it is typically used when training a new GAN from scratch is too expensive.
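For illustration, one common form of GAN inversion is optimization-based: keep a pretrained generator frozen and run gradient descent on the latent vector z until G(z) matches a target image (for example, an image of the digit you want). The toy sketch below assumes a hypothetical one-layer generator G(z) = tanh(Wz) with random weights and tiny sizes, so the gradient can be written by hand in NumPy; with a real pretrained GAN you would let the framework's autograd compute it.

```python
import numpy as np

rng = np.random.default_rng(0)
LATENT_DIM, IMG_SIZE = 8, 64   # toy sizes (a real MNIST GAN might use 100 and 784)

# Hypothetical stand-in for a pretrained, frozen generator: G(z) = tanh(W z).
W = rng.normal(size=(IMG_SIZE, LATENT_DIM)) / np.sqrt(LATENT_DIM)

def G(z):
    return np.tanh(W @ z)

def invert(target, z0, steps=3000, lr=0.005):
    """Gradient descent on z to minimise ||G(z) - target||^2; W is never updated."""
    z = z0.copy()
    for _ in range(steps):
        g = G(z)
        # Chain rule through tanh: dL/dz = W^T [2 (g - target) * (1 - g^2)]
        grad = W.T @ (2.0 * (g - target) * (1.0 - g ** 2))
        z -= lr * grad
    return z

target = G(rng.normal(size=LATENT_DIM))   # an "image" the generator can produce
z0 = rng.normal(size=LATENT_DIM)          # random starting latent
z_found = invert(target, z0)

loss_before = np.sum((G(z0) - target) ** 2)
loss_after = np.sum((G(z_found) - target) ** 2)
print(loss_before, loss_after)
```

Only z is optimized, which is why this avoids retraining the GAN; the catch is that the search can stop in a local minimum, so in practice it is often restarted from several initial latents.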

Alberto