Move your mouse over the graph to generate a digit.
How it works: I trained a variational autoencoder (VAE) on the MNIST dataset using a 2-dimensional latent space. The graph below shows each image in the training set encoded to its mean in the latent space, plotted as an xy point and colored by its label. I trained the model in Google Colab with PyTorch, exported the decoder to ONNX format, and used ONNX.js to load the model and run it in the browser.
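In outline, the model looks something like this (the fully-connected layers and their sizes here are illustrative, not necessarily the exact network):

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, latent_dim=2):
        super().__init__()
        self.hidden = nn.Sequential(nn.Flatten(), nn.Linear(784, 512), nn.ReLU())
        self.mu = nn.Linear(512, latent_dim)       # mean of q(z|x)
        self.log_var = nn.Linear(512, latent_dim)  # log-variance of q(z|x)

    def forward(self, x):
        h = self.hidden(x)
        return self.mu(h), self.log_var(h)

class Decoder(nn.Module):
    def __init__(self, latent_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 512), nn.ReLU(),
            nn.Linear(512, 784), nn.Sigmoid(),  # pixel intensities in [0, 1]
        )

    def forward(self, z):
        return self.net(z).view(-1, 1, 28, 28)

def reparameterize(mu, log_var):
    # Sample z = mu + sigma * eps with eps ~ N(0, I), so gradients flow
    # through mu and log_var (the reparameterization trick).
    eps = torch.randn_like(mu)
    return mu + torch.exp(0.5 * log_var) * eps
```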
The point that you move the mouse over is interpreted as a sample from the latent space. It is passed into the decoder to get an image, which is then drawn on a canvas to the right.
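The export step is a single PyTorch call, and the hover-to-image step is equivalent to decoding one latent point (file name and coordinates below are placeholders):

```python
decoder = Decoder(latent_dim=2)  # trained weights would be loaded here

# Export only the decoder; the browser just needs the latent-to-image half.
torch.onnx.export(decoder, torch.zeros(1, 2), "decoder.onnx")

# What ONNX.js does on hover, expressed in PyTorch: decode one latent point.
with torch.no_grad():
    img = decoder(torch.tensor([[0.5, -1.2]]))  # shape (1, 1, 28, 28)
```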
While exploring the digit clusters in the latent space above, I noticed that each cluster captures similar characteristics of its digit, like rotation and scaling. I was curious whether the latent distribution could learn to capture these characteristics in a general way, across all digits.
I tweaked the autoencoder by appending a 10-dimensional one-hot vector representing the digit class to the 2 sampled latents before passing them into the decoder. My theory was that by explicitly providing the label to the decoder, the VAE would use the latent space to represent characteristics that were label-invariant.
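The change amounts to a concatenation before decoding. A sketch, reusing reparameterize from above and assuming the decoder's input layer is widened to 12 (2 latents + 10 classes):

```python
import torch
import torch.nn.functional as F

def decode_with_label(decoder, mu, log_var, labels):
    # Append a one-hot class vector to the sampled latents, so the
    # decoder sees both where we are in latent space and which digit.
    z = reparameterize(mu, log_var)                # (batch, 2)
    y = F.one_hot(labels, num_classes=10).float()  # (batch, 10)
    return decoder(torch.cat([z, y], dim=1))       # (batch, 1, 28, 28)
```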
As shown below, this caused the VAE's latent space to become a mess of color, indicating that the digits are evenly distributed throughout, and that the VAE likely learned to rely on the provided label rather than trying to cluster digits. In fact, this model achieved a much better loss than the previous one, and it did use the latents to represent general characteristics like shearing (x-axis) and boldness (y-axis).
VAEs are distinct from regular autoencoders in that their latent representation is a probability distribution, not an arbitrary vector. As a result, the VAE loss function has two terms: the reconstruction loss (the same as in a regular autoencoder: comparing the input to the output) and a term measuring the difference between the current and desired latent distributions, computed with KL divergence.
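In PyTorch the two terms look roughly like this. I'm assuming binary cross-entropy for the reconstruction term, a common choice for MNIST, though other reconstruction losses work too:

```python
import torch
import torch.nn.functional as F

def vae_loss(x, recon, mu, log_var):
    # Term 1: reconstruction, comparing the decoder output to the input.
    recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
    # Term 2: KL divergence from q(z|x) to the standard-Gaussian prior,
    # using the closed form given below.
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl
```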
Typically, and in the experiments above, the desired latent distribution is just a standard Gaussian: zero mean and identity covariance. This pushes the latents to cluster around the origin, with similar data grouped together (as seen above), which makes the space convenient to sample from and yields realistic interpolations between samples. Using a Gaussian also makes it possible to compute the KL divergence loss term in closed form.
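For an encoder output $\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ measured against the standard-normal prior, that closed form is:

$$
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\big\|\, \mathcal{N}(0, I)\right)
= \frac{1}{2} \sum_{i} \left( \mu_i^2 + \sigma_i^2 - \ln \sigma_i^2 - 1 \right)
$$

which is exactly the `kl` line in the code above, with `log_var` standing in for $\ln \sigma_i^2$.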
I experimented with manipulating the latent distribution using the digit labels. Instead of computing the KL divergence loss against the same standard Gaussian for every training image, I made the target mean different for each digit (which required rederiving the closed-form KL divergence term for an arbitrary target mean). As shown below, this worked surprisingly well: the distributions of each digit were successfully forced around distinct, predefined points in the latent space. Even more surprising is the fact that interpolation between digits is still fairly smooth.
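The rederivation is small: with a target of $\mathcal{N}(m, I)$ instead of $\mathcal{N}(0, I)$, the $\mu_i^2$ in the closed form becomes $(\mu_i - m_i)^2$. A sketch of how that might look in the training loop (the target points here are illustrative, not necessarily the layout used above):

```python
import torch

# Ten predefined target means, one per digit, spaced on a circle of
# radius 3. These particular points are an assumption for illustration.
angles = torch.arange(10) * (2 * torch.pi / 10)
target_means = 3.0 * torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)

def kl_to_target(mu, log_var, labels):
    m = target_means[labels]  # (batch, 2): each example's digit-specific mean
    # Same closed form as before, with mu^2 replaced by (mu - m)^2.
    return -0.5 * torch.sum(1 + log_var - (mu - m).pow(2) - log_var.exp())
```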
In the experiment above, prescribing a mean for each digit's distribution produced distinct digit clusters. Producing clearly distinguishable digit clusters is effectively a classification task. To see if this goal could be made more explicit, I experimented with training the decoder to jointly classify digits alongside outputting image samples.
In this experiment, the decoder outputs both a sample image and a 10-element vector for digit classification. On each iteration of the training loop, the loss is computed by adding a weighted classification loss to the VAE loss terms. I had to tweak the weighting a bit before noticing an effect.
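A sketch of the combined objective, reusing vae_loss from above; it assumes the decoder now returns (image, logits), and the weight value is a placeholder for whatever the hand-tuning landed on:

```python
import torch.nn.functional as F

CLS_WEIGHT = 50.0  # placeholder value; tuned by hand in practice

def joint_loss(x, recon, logits, mu, log_var, labels):
    # VAE terms plus a weighted cross-entropy on the decoder's class logits.
    return (vae_loss(x, recon, mu, log_var)
            + CLS_WEIGHT * F.cross_entropy(logits, labels))
```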
The classification loss causes the encoder to separate digits in the latent space, because the decoder must rely entirely on the latent samples to classify them. The result is digit clusters with much cleaner boundaries than the vanilla VAE's, while still being distributed around the origin. However, another effect I've observed is that interpolation becomes somewhat less realistic, since there are unused gaps between the digit clusters. The size of the gaps depends on the weight of the classification loss term.