<h1>The geometry of diffusion guidance</h1>
<p><em>Sander Dieleman, 28 August 2023</em></p>
<p>Guidance is a powerful method that can be used to enhance diffusion model sampling. As I’ve discussed in <a href="https://sander.ai/2022/05/26/guidance.html">an earlier blog post</a>, it’s almost like a cheat code: it can improve sample quality so much that it’s as if the model had ten times the number of parameters – an order of magnitude improvement, basically for free! This follow-up post provides a geometric interpretation and visualisation of the diffusion sampling procedure, which I’ve found particularly useful to explain how guidance works.</p>
<h2 id="-a-word-of-warning-about-high-dimensional-spaces"><a name="warning"></a> A word of warning about high-dimensional spaces</h2>
<figure>
<a href="/images/dimensions.jpg"><img src="/images/dimensions.jpg" /></a>
</figure>
<p>Sampling algorithms for diffusion models typically start by initialising a <em>canvas</em> with random noise, and then repeatedly updating this canvas based on model predictions, until a sample from the model distribution eventually emerges.</p>
<p>We will represent this canvas by a vector \(\mathbf{x}_t\), where \(t\) represents the current time step in the sampling procedure. By convention, the diffusion process which gradually corrupts inputs into random noise moves forward in time from \(t=0\) to \(t=T\), so the sampling procedure goes backward in time, from \(t=T\) to \(t=0\). Therefore \(\mathbf{x}_T\) corresponds to random noise, and \(\mathbf{x}_0\) corresponds to a sample from the data distribution.</p>
<p><strong>\(\mathbf{x}_t\) is a high-dimensional vector</strong>: for example, if a diffusion model produces images of size 64x64, there are 12,288 different scalar intensity values (3 colour channels per pixel). The sampling procedure then traces a path through a 12,288-dimensional Euclidean space.</p>
<p>It’s pretty difficult for the human brain to comprehend what that actually looks like in practice. Because our intuition is firmly rooted in our 3D surroundings, it actually tends to fail us in surprising ways in high-dimensional spaces. A while back, I wrote <a href="https://sander.ai/2020/09/01/typicality.html">a blog post</a> about some of the implications for high-dimensional probability distributions in particular. <a href="http://www.penzba.co.uk/cgi-bin/PvsNP.py?SpikeySpheres#HN2">This note about why high-dimensional spheres are “spikey”</a> is also worth a read, if you quickly want to get a feel for how weird things can get. A more thorough treatment of high-dimensional geometry can be found in chapter 2 of ‘Foundations of Data Science’<sup id="fnref:foundations" role="doc-noteref"><a href="#fn:foundations" class="footnote" rel="footnote">1</a></sup> by Blum, Hopcroft and Kannan, which is <a href="https://www.cs.cornell.edu/jeh/book.pdf">available to download in PDF format</a>.</p>
<p>Nevertheless, in this blog post, <strong>I will use diagrams that represent \(\mathbf{x}_t\) in <em>two</em> dimensions</strong>, because unfortunately that’s all the spatial dimensions available on your screen. <strong>This is dangerous</strong>: following our intuition in 2D might lead us to the wrong conclusions. But I’m going to do it anyway, because in spite of this, I’ve found these diagrams quite helpful to explain how manipulations such as guidance affect diffusion sampling in practice.</p>
<p>Here’s some advice from Geoff Hinton on dealing with high-dimensional spaces that may or may not help:</p>
<blockquote class="twitter-tweet"><p lang="en" dir="ltr">I'm laughing so hard at this slide a friend sent me from one of Geoff Hinton's courses;<br /><br />"To deal with hyper-planes in a 14-dimensional space, visualize a 3-D space and say 'fourteen' to yourself very loudly. Everyone does it." <a href="https://t.co/nTakZArbsD">pic.twitter.com/nTakZArbsD</a></p>— Robbie Barrat (@videodrome) <a href="https://twitter.com/videodrome/status/1005887240407379969?ref_src=twsrc%5Etfw">June 10, 2018</a></blockquote>
<script async="" src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
<p>… anyway, <strong>you’ve been warned!</strong></p>
<h2 id="-visualising-diffusion-sampling"><a name="sampling"></a> Visualising diffusion sampling</h2>
<figure>
<a href="/images/dice.jpg"><img src="/images/dice.jpg" /></a>
</figure>
<p>To start off, let’s visualise what a step of diffusion sampling typically looks like. I will use a real photograph to which I’ve added varying amounts of noise to stand in for intermediate samples in the diffusion sampling process:</p>
<figure>
<a href="/images/noisy_bundle_128.png"><img src="/images/noisy_bundle_128.png" alt="Bundle the bunny, with varying amounts of noise added." /></a>
<figcaption>Bundle the bunny, with varying amounts of noise added. <a href="https://twitter.com/kipperrii/status/1574557416741474304">Photo credit: kipply</a>.</figcaption>
</figure>
<p>During diffusion model training, examples of noisy images are produced by taking examples of clean images from the data distribution, and adding varying amounts of noise to them. This is what I’ve done above. During sampling, we start from a canvas that is pure noise, and then the model gradually removes random noise and replaces it with meaningful structure in accordance with the data distribution. Note that I will be using this set of images to represent intermediate samples from a model, even though that’s not how they were constructed. If the model is good enough, you shouldn’t be able to tell the difference anyway!</p>
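<p>For concreteness, noisy versions like the ones above can be constructed in one shot from the clean image. The following is a minimal numpy sketch, assuming the standard variance-preserving formulation; the function name and the specific noise levels are my own illustrative choices, not something defined in this post:</p>

```python
import numpy as np

# Hedged sketch of the variance-preserving forward process:
#   x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise,
# where alpha_bar decreases from 1 (clean) to 0 (pure noise).
def add_noise(x0, alpha_bar, rng):
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * noise

rng = np.random.default_rng(0)
x0 = rng.standard_normal((64, 64, 3))   # stand-in for a clean image
for alpha_bar in [0.99, 0.5, 0.01]:     # low, medium, high noise
    xt = add_noise(x0, alpha_bar, rng)
    # the overall variance of x_t stays close to 1 at every noise level,
    # which is what "variance-preserving" refers to
    print(round(float(xt.var()), 1))
```

<p>Note that training pairs are produced this way in a single step per example; the iterative part is only the sampling procedure, which has to undo the corruption gradually.</p>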
<p>In the diagram below, we have an intermediate noisy sample \(\mathbf{x}_t\), somewhere in the middle of the sampling process, as well as the final output of that process \(\mathbf{x}_0\), which is noise-free:</p>
<figure>
<a href="/images/geometry_diagram001.png"><img src="/images/geometry_diagram001.png" style="border: 1px dotted #bbb;" alt="Diagram showing an intermediate noisy sample, as well as the final output of the sampling process." /></a>
<figcaption>Diagram showing an intermediate noisy sample, as well as the final output of the sampling process.</figcaption>
</figure>
<p>Imagine the two spatial dimensions of your screen representing just two of many thousands of pixel colour intensities (red, green or blue). Different spatial positions in the diagram correspond to different images. A single step in the sampling procedure is taken by using the model to <strong>predict where the final sample will end up</strong>. We’ll call this prediction \(\hat{\mathbf{x}}_0\):</p>
<figure>
<a href="/images/geometry_diagram002.png"><img src="/images/geometry_diagram002.png" style="border: 1px dotted #bbb;" alt="Diagram showing the prediction of the final sample from the current step in the sampling process." /></a>
<figcaption>Diagram showing the prediction of the final sample from the current step in the sampling process.</figcaption>
</figure>
<p>Note how this prediction is roughly in the direction of \(\mathbf{x}_0\), but we’re not able to predict \(\mathbf{x}_0\) exactly from the current point in the sampling process, \(\mathbf{x}_t\), because the noise obscures a lot of information (especially fine-grained details), which we aren’t able to fill in all in one go. Indeed, if we were, there would be no point to this iterative sampling procedure: we could just go directly from pure noise \(\mathbf{x}_T\) to a clean image \(\mathbf{x}_0\) in one step. (As an aside, this is more or less what Consistency Models<sup id="fnref:cm" role="doc-noteref"><a href="#fn:cm" class="footnote" rel="footnote">2</a></sup> try to achieve.)</p>
<p><strong>Diffusion models estimate the <em>expectation</em> of \(\mathbf{x}_0\)</strong>, given the current noisy input \(\mathbf{x}_t\): \(\hat{\mathbf{x}}_0 = \mathbb{E}[\mathbf{x}_0 \mid \mathbf{x}_t]\). At the highest noise levels, this expectation basically corresponds to the mean of the entire dataset, because very noisy inputs are not very informative. As a result, the prediction \(\hat{\mathbf{x}}_0\) will look like a very blurry image when visualised. At lower noise levels, this prediction will become sharper and sharper, and it will eventually resemble a sample from the data distribution. In <a href="https://sander.ai/2023/07/20/perspectives.html#expectation">a previous blog post</a>, I go into a little bit more detail about why diffusion models end up estimating expectations.</p>
<p>In practice, diffusion models are often parameterised to predict noise, rather than clean input, which I also discussed in <a href="https://sander.ai/2023/07/20/perspectives.html#conventions">the same blog post</a>. Some models also predict time-dependent linear combinations of the two. Long story short, all of these parameterisations are equivalent once the model has been trained, because a prediction of one of these quantities can be turned into a prediction of another through a linear combination of the prediction itself and the noisy input \(\mathbf{x}_t\). That’s why we can always get a prediction \(\hat{\mathbf{x}}_0\) out of any diffusion model, regardless of how it was parameterised or trained: for example, if the model predicts the noise, simply take the noisy input and subtract the predicted noise.</p>
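<p>To make the equivalence of parameterisations concrete, here is a hedged numpy sketch of one such conversion, assuming the variance-preserving formulation \(\mathbf{x}_t = \sqrt{\bar\alpha_t} \mathbf{x}_0 + \sqrt{1 - \bar\alpha_t} \varepsilon\) (standard notation, but not otherwise used in this post):</p>

```python
import numpy as np

# Turning a noise prediction eps_hat into a clean-input prediction x0_hat,
# by inverting x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps.
def x0_from_eps(xt, eps_hat, alpha_bar):
    return (xt - np.sqrt(1.0 - alpha_bar) * eps_hat) / np.sqrt(alpha_bar)

# Sanity check: if the model predicted the true noise exactly,
# we would recover x_0 exactly.
rng = np.random.default_rng(0)
x0 = rng.standard_normal(12288)   # 64x64x3 image, flattened
eps = rng.standard_normal(12288)
alpha_bar = 0.3
xt = np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * eps
print(np.allclose(x0_from_eps(xt, eps, alpha_bar), x0))  # True
```

<p>The conversion is just a linear combination of the prediction and \(\mathbf{x}_t\), which is why the choice of parameterisation does not matter once the model is trained.</p>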
<p>Diffusion model predictions also correspond to an estimate of the so-called <em>score function</em>, \(\nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)\). This can be interpreted as the direction in input space along which the log-likelihood of the input increases maximally. In other words, it’s the answer to the question: <strong>“how should I change the input to make it more likely?”</strong> Diffusion sampling now proceeds by <strong>taking a small step in the direction of this prediction</strong>:</p>
<figure>
<a href="/images/geometry_diagram003.png"><img src="/images/geometry_diagram003.png" style="border: 1px dotted #bbb;" alt="Diagram showing how we take a small step in the direction of the prediction of the final sample." /></a>
<figcaption>Diagram showing how we take a small step in the direction of the prediction of the final sample.</figcaption>
</figure>
<p>This should look familiar to any machine learning practitioner, as it’s very similar to neural network training via gradient descent: backpropagation gives us the direction of steepest descent at the current point in parameter space, and at each optimisation step, we take a small step in that direction. Taking a very large step wouldn’t get us anywhere interesting, because the estimated direction is only valid locally. The same is true for diffusion sampling, except we’re now operating in the input space, rather than in the space of model parameters.</p>
<p>What happens next depends on the specific sampling algorithm we’ve chosen to use. There are many to choose from: DDPM<sup id="fnref:ddpm" role="doc-noteref"><a href="#fn:ddpm" class="footnote" rel="footnote">3</a></sup> (also called ancestral sampling), DDIM<sup id="fnref:ddim" role="doc-noteref"><a href="#fn:ddim" class="footnote" rel="footnote">4</a></sup>, DPM++<sup id="fnref:dpmpp" role="doc-noteref"><a href="#fn:dpmpp" class="footnote" rel="footnote">5</a></sup> and ODE-based sampling<sup id="fnref:sde" role="doc-noteref"><a href="#fn:sde" class="footnote" rel="footnote">6</a></sup> (with many sub-variants using different ODE solvers) are just a few examples. Some of these algorithms are deterministic, which means the only source of randomness in the sampling procedure is the initial noise on the canvas. Others are stochastic, which means that further noise is injected at each step of the sampling procedure.</p>
<p>We’ll use DDPM as an example, because it is one of the oldest and most commonly used sampling algorithms for diffusion models. This is a stochastic algorithm, so <strong>some random noise is added after taking a step</strong> in the direction of the model prediction:</p>
<figure>
<a href="/images/geometry_diagram004.png"><img src="/images/geometry_diagram004.png" style="border: 1px dotted #bbb;" alt="Diagram showing how noise is added after taking a small step in the direction of the model prediction." /></a>
<figcaption>Diagram showing how noise is added after taking a small step in the direction of the model prediction.</figcaption>
</figure>
<p>Note that I am intentionally glossing over some of the details of the sampling algorithm here (for example, the exact variance of the noise \(\varepsilon\) that is added at each step). The diagrams are schematic and the focus is on building intuition, so I think I can get away with that, but obviously it’s pretty important to get this right when you actually want to implement this algorithm.</p>
<p>For deterministic sampling algorithms, we can simply skip this step (i.e. set \(\varepsilon = 0\)). After this, we end up in \(\mathbf{x}_{t-1}\), which is the next iterate in the sampling procedure, and should correspond to a slightly less noisy sample. <strong>To proceed, we rinse and repeat</strong>. We can again make a prediction \(\hat{\mathbf{x}}_0\):</p>
<figure>
<a href="/images/geometry_diagram005.png"><img src="/images/geometry_diagram005.png" style="border: 1px dotted #bbb;" alt="Diagram showing the updated prediction of the final sample from the current step in the sampling process." /></a>
<figcaption>Diagram showing the updated prediction of the final sample from the current step in the sampling process.</figcaption>
</figure>
<p>Because we are in a different point in input space, this prediction will also be different. Concretely, <strong>as the input to the model is now slightly less noisy, the prediction will be slightly less blurry</strong>. We now take a small step in the direction of this new prediction, and add noise to end up in \(\mathbf{x}_{t-2}\):</p>
<figure>
<a href="/images/geometry_diagram006.png"><img src="/images/geometry_diagram006.png" style="border: 1px dotted #bbb;" alt="Diagram showing a sequence of two DDPM sampling steps." /></a>
<figcaption>Diagram showing a sequence of two DDPM sampling steps.</figcaption>
</figure>
<p>We can keep doing this until we eventually reach \(\mathbf{x}_0\), and we will have drawn a sample from the diffusion model. To summarise, below is an animated version of the above set of diagrams, showing the sequence of steps:</p>
<figure>
<a href="/images/geometry_diagram007.gif"><img src="/images/geometry_diagram007.gif" style="border: 1px dotted #c10;" alt="Animation of the above set of diagrams." /></a>
<figcaption>Animation of the above set of diagrams.</figcaption>
</figure>
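<p>The rinse-and-repeat procedure shown in the animation can be sketched in a few lines of Python. Everything below is schematic: the "model" is a placeholder, and the step size and noise scale are illustrative rather than the exact DDPM coefficients (which, as noted above, matter a great deal in a real implementation):</p>

```python
import numpy as np

# Placeholder for the diffusion model's prediction x0_hat = E[x0 | x_t].
# A real model would be a neural network conditioned on t.
def fake_model(xt, t):
    return 0.9 * xt

# Schematic DDPM-style sampling: predict x0_hat, take a small step
# towards it, inject fresh noise, repeat until t = 0.
def sample(T=50, dim=2, step=0.2, noise_scale=0.1, seed=0):
    rng = np.random.default_rng(seed)
    xt = rng.standard_normal(dim)        # x_T: pure noise
    for t in range(T, 0, -1):
        x0_hat = fake_model(xt, t)       # predict the final sample
        xt = xt + step * (x0_hat - xt)   # small step towards the prediction
        if t > 1:                        # deterministic samplers skip this
            xt = xt + noise_scale * rng.standard_normal(dim)
    return xt

print(sample().shape)
```

<p>Setting <code>noise_scale=0</code> corresponds to a deterministic sampler (\(\varepsilon = 0\)), where the initial canvas is the only source of randomness.</p>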
<h2 id="-classifier-guidance"><a name="classifier-guidance"></a> Classifier guidance</h2>
<figure>
<a href="/images/sorted.jpg"><img src="/images/sorted.jpg" /></a>
</figure>
<p>Classifier guidance<sup id="fnref:sde:1" role="doc-noteref"><a href="#fn:sde" class="footnote" rel="footnote">6</a></sup> <sup id="fnref:equilibrium" role="doc-noteref"><a href="#fn:equilibrium" class="footnote" rel="footnote">7</a></sup> <sup id="fnref:beatgans" role="doc-noteref"><a href="#fn:beatgans" class="footnote" rel="footnote">8</a></sup> provides a way to <strong>steer diffusion sampling in the direction that maximises the probability of the final sample being classified as a particular class</strong>. More broadly, this can be used to make the sample reflect any sort of conditioning signal that wasn’t provided to the diffusion model during training. In other words, it enables <em>post-hoc</em> conditioning.</p>
<p>For classifier guidance, we need an auxiliary model that predicts \(p(y \mid \mathbf{x})\), where \(y\) represents an arbitrary input feature, which could be a class label, a textual description of the input, or even a more structured object like a segmentation map or a depth map. We’ll call this model a <em>classifier</em>, but keep in mind that we can use many different kinds of models for this purpose, not just classifiers in the narrow sense of the word. What’s nice about this setup, is that such models are usually smaller and easier to train than diffusion models.</p>
<p>One important caveat is that we will be applying this auxiliary model to <em>noisy</em> inputs \(\mathbf{x}_t\), at varying levels of noise, so it has to be robust against this particular type of input distortion. This seems to preclude the use of off-the-shelf classifiers, and implies that we need to train a custom noise-robust classifier, or at the very least, fine-tune an off-the-shelf classifier to be noise-robust. We can also explicitly condition the classifier on the time step \(t\), so the level of noise does not have to be inferred from the input \(\mathbf{x}_t\) alone.</p>
<p>However, it turns out that we can construct a reasonable noise-robust classifier by combining an off-the-shelf classifier (which expects noise-free inputs) with our diffusion model. Rather than applying the classifier to \(\mathbf{x}_t\), we first predict \(\hat{\mathbf{x}}_0\) with the diffusion model, and use that as input to the classifier instead. \(\hat{\mathbf{x}}_0\) is still distorted, but by blurring rather than by Gaussian noise. Off-the-shelf classifiers tend to be much more robust to this kind of distortion out of the box. Bansal et al.<sup id="fnref:universal" role="doc-noteref"><a href="#fn:universal" class="footnote" rel="footnote">9</a></sup> named this trick “forward universal guidance”, though it has been known for some time. They also suggest some more advanced approaches for post-hoc guidance.</p>
<p>Using the classifier, we can now determine the direction in input space that maximises the log-likelihood of the conditioning signal, simply by computing <strong>the gradient with respect to \(\mathbf{x}_t\)</strong>: \(\nabla_{\mathbf{x}_t} \log p(y \mid \mathbf{x}_t)\). (Note: if we used the above trick to construct a noise-robust classifier from an off-the-shelf one, this means we’ll need to backpropagate through the diffusion model as well.)</p>
<figure>
<a href="/images/geometry_diagram008.png"><img src="/images/geometry_diagram008.png" style="border: 1px dotted #bbb;" alt="Diagram showing the update directions from the diffusion model and the classifier." /></a>
<figcaption>Diagram showing the update directions from the diffusion model and the classifier.</figcaption>
</figure>
<p>To apply classifier guidance, we <strong>combine the directions obtained from the diffusion model and from the classifier by adding them together</strong>, and then we take a step in this combined direction instead:</p>
<figure>
<a href="/images/geometry_diagram009.png"><img src="/images/geometry_diagram009.png" style="border: 1px dotted #bbb;" alt="Diagram showing the combined update direction for classifier guidance." /></a>
<figcaption>Diagram showing the combined update direction for classifier guidance.</figcaption>
</figure>
<p>As a result, the sampling procedure will trace a different trajectory through the input space. To control the influence of the conditioning signal on the sampling procedure, we can <strong>scale the contribution of the classifier gradient by a factor \(\gamma\)</strong>, which is called the <em>guidance scale</em>:</p>
<figure>
<img src=&quot;/images/geometry_diagram010.png&quot;">
<a href="/images/geometry_diagram010.png"><img src="/images/geometry_diagram010.png" style="border: 1px dotted #bbb;" alt="Diagram showing the scaled classifier update direction." /></a>
<figcaption>Diagram showing the scaled classifier update direction.</figcaption>
</figure>
<p>The combined update direction will then be influenced more strongly by the direction obtained from the classifier (provided that \(\gamma > 1\), which is usually the case):</p>
<figure>
<a href="/images/geometry_diagram011.png"><img src="/images/geometry_diagram011.png" style="border: 1px dotted #bbb;" alt="Diagram showing the combined update direction for classifier guidance with guidance scale." /></a>
<figcaption>Diagram showing the combined update direction for classifier guidance with guidance scale.</figcaption>
</figure>
<p>This scale factor \(\gamma\) is an important sampling hyperparameter: if it’s too low, the effect is negligible. If it’s too high, the samples will be distorted and low-quality. This is because <strong>gradients obtained from classifiers don’t necessarily point in directions that lie on the image manifold</strong> – if we’re not careful, we may actually end up in adversarial examples, which maximise the probability of the class label but don’t actually look like an example of the class at all!</p>
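<p>Putting the pieces together, the classifier-guided update direction is just a weighted vector sum. In the toy sketch below, the "classifier" is entirely made up (a quadratic log-likelihood, so its gradient has a closed form); in practice this gradient comes from backpropagating through a noise-robust classifier, and the score comes from the diffusion model:</p>

```python
import numpy as np

# Classifier guidance combines the model's score estimate with a
# gamma-scaled classifier gradient, as in the diagrams above.
def guided_direction(score, classifier_grad, gamma):
    return score + gamma * classifier_grad

x_t = np.array([1.0, 0.0])
mu_y = np.array([0.0, 2.0])      # hypothetical class "centre"
score = -x_t                     # toy score: pulls towards the origin
classifier_grad = mu_y - x_t     # gradient of -||x - mu_y||^2 / 2

for gamma in [0.0, 1.0, 3.0]:
    # larger gamma pulls the update direction further towards mu_y
    print(guided_direction(score, classifier_grad, gamma))
```

<p>With \(\gamma = 0\) we recover unguided sampling; cranking \(\gamma\) up amplifies the classifier's influence, with the failure mode described above if it is pushed too far.</p>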
<p>In <a href="https://sander.ai/2022/05/26/guidance.html">my previous blog post on diffusion guidance</a>, I made the connection between these operations on vectors in the input space, and the underlying manipulations of distributions they correspond to. It’s worth briefly revisiting this connection to make it more apparent:</p>
<ul>
<li>
<p>We’ve taken the update direction obtained from the diffusion model, which corresponds to \(\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t)\) (i.e. the score function), and the (scaled) update direction obtained from the classifier, \(\gamma \cdot \nabla_{\mathbf{x}_t} \log p(y \mid \mathbf{x}_t)\), and combined them simply by adding them together: \(\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t) + \gamma \cdot \nabla_{\mathbf{x}_t} \log p(y \mid \mathbf{x}_t)\).</p>
</li>
<li>
<p>This expression corresponds to the gradient of the logarithm of \(p_t(\mathbf{x}_t) \cdot p(y \mid \mathbf{x}_t)^\gamma\).</p>
</li>
<li>
<p>In other words, we have effectively <em>reweighted</em> the model distribution, changing the probability of each input in accordance with the probability the classifier assigns to the desired class label.</p>
</li>
<li>
<p>The guidance scale \(\gamma\) corresponds to the <em>temperature</em> of the classifier distribution. A high temperature implies that inputs to which the classifier assigns high probabilities are upweighted more aggressively, relative to other inputs.</p>
</li>
<li>
<p>The result is a new model that is much more likely to produce samples that align with the desired class label.</p>
</li>
</ul>
<p>An animated diagram of a single step of sampling with classifier guidance is shown below:</p>
<figure>
<a href="/images/geometry_diagram018.gif"><img src="/images/geometry_diagram018.gif" style="border: 1px dotted #c10;" alt="Animation of a single step of sampling with classifier guidance." /></a>
<figcaption>Animation of a single step of sampling with classifier guidance.</figcaption>
</figure>
<h2 id="-classifier-free-guidance"><a name="classifier-free-guidance"></a> Classifier-free guidance</h2>
<figure>
<a href="/images/winding_road.jpg"><img src="/images/winding_road.jpg" /></a>
</figure>
<p>Classifier-free guidance<sup id="fnref:cf" role="doc-noteref"><a href="#fn:cf" class="footnote" rel="footnote">10</a></sup> is a variant of guidance that does not require an auxiliary classifier model. Instead, <strong>a Bayesian classifier is constructed by combining a conditional and an unconditional generative model</strong>.</p>
<p>Concretely, when training a conditional generative model \(p(\mathbf{x}\mid y)\), we can drop out the conditioning \(y\) some percentage of the time (usually 10-20%) so that the same model can also act as an unconditional generative model, \(p(\mathbf{x})\). It turns out that this does not have a detrimental effect on conditional modelling performance. Using Bayes’ rule, we find that \(p(y \mid \mathbf{x}) \propto \frac{p(\mathbf{x}\mid y)}{p(\mathbf{x})}\), which gives us a way to turn our generative model into a classifier.</p>
<p>In diffusion models, we tend to express this in terms of score functions, rather than in terms of probability distributions. Taking the logarithm and then the gradient w.r.t. \(\mathbf{x}\), we get \(\nabla_\mathbf{x} \log p(y \mid \mathbf{x}) = \nabla_\mathbf{x} \log p(\mathbf{x} \mid y) - \nabla_\mathbf{x} \log p(\mathbf{x})\). In other words, to obtain the gradient of the classifier log-likelihood with respect to the input, all we have to do is subtract the unconditional score function from the conditional score function.</p>
<p>Substituting this expression into the formula for the update direction of classifier guidance, we obtain the following:</p>
\[\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t) + \gamma \cdot \nabla_{\mathbf{x}_t} \log p(y \mid \mathbf{x}_t)\]
\[= \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t) + \gamma \cdot \left( \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t \mid y) - \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t) \right)\]
\[= (1 - \gamma) \cdot \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t) + \gamma \cdot \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t \mid y) .\]
<p>The update direction is now a linear combination of the unconditional and conditional score functions. It would be a convex combination if it were the case that \(\gamma \in [0, 1]\), but in practice \(\gamma > 1\) tends to be where the magic happens, so this is merely a <em>barycentric</em> combination. Note that \(\gamma = 0\) reduces to the unconditional case, and \(\gamma = 1\) reduces to the conditional (<em>unguided</em>) case.</p>
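<p>This barycentric combination is easy to verify numerically. A minimal sketch (the function name is mine; the two input directions stand in for the unconditional and conditional score estimates):</p>

```python
import numpy as np

# Classifier-free guidance: (1 - gamma) * uncond + gamma * cond,
# equivalently uncond + gamma * (cond - uncond), i.e. extrapolating
# from the unconditional direction along the difference vector.
def cfg(uncond, cond, gamma):
    return (1.0 - gamma) * uncond + gamma * cond

uncond = np.array([1.0, 0.0])
cond = np.array([0.0, 1.0])

print(np.allclose(cfg(uncond, cond, 0.0), uncond))                        # True
print(np.allclose(cfg(uncond, cond, 1.0), cond))                          # True
print(np.allclose(cfg(uncond, cond, 2.0), uncond + 2 * (cond - uncond)))  # True
```

<p>The two algebraic forms are identical, but the extrapolation view is the one that maps directly onto the diagrams that follow.</p>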
<p>How do we make sense of this geometrically? With our hybrid conditional/unconditional model, we can make two predictions \(\hat{\mathbf{x}}_0\). These will be different, because the conditioning information may allow us to make a more accurate prediction:</p>
<figure>
<a href="/images/geometry_diagram012.png"><img src="/images/geometry_diagram012.png" style="border: 1px dotted #bbb;" alt="Diagram showing the conditional and unconditional predictions." /></a>
<figcaption>Diagram showing the conditional and unconditional predictions.</figcaption>
</figure>
<p>Next, we determine the difference vector between these two predictions. As we showed earlier, this corresponds to the gradient direction provided by the implied Bayesian classifier:</p>
<figure>
<a href="/images/geometry_diagram013.png"><img src="/images/geometry_diagram013.png" style="border: 1px dotted #bbb;" alt="Diagram showing the difference vector obtained by subtracting the directions corresponding to the two predictions." /></a>
<figcaption>Diagram showing the difference vector obtained by subtracting the directions corresponding to the two predictions.</figcaption>
</figure>
<p>We now scale this vector by \(\gamma\):</p>
<figure>
<a href="/images/geometry_diagram014.png"><img src="/images/geometry_diagram014.png" style="border: 1px dotted #bbb;" alt="Diagram showing the amplified difference vector." /></a>
<figcaption>Diagram showing the amplified difference vector.</figcaption>
</figure>
<p>Starting from the unconditional prediction for \(\hat{\mathbf{x}}_0\), this vector points towards a new implicit prediction, which corresponds to a stronger influence of the conditioning signal. This is the prediction we will now take a small step towards:</p>
<figure>
<img src=&quot;/images/geometry_diagram015.png&quot;">
<a href="/images/geometry_diagram015.png"><img src="/images/geometry_diagram015.png" style="border: 1px dotted #bbb;" alt="Diagram showing the direction to step in for classifier-free guidance." /></a>
<figcaption>Diagram showing the direction to step in for classifier-free guidance.</figcaption>
</figure>
<p>Classifier-free guidance tends to work a lot better than classifier guidance, because the Bayesian classifier is much more robust than a separately trained one, and the resulting update directions are much less likely to be adversarial. On top of that, it doesn’t require an auxiliary model, and generative models can be made compatible with classifier-free guidance simply through <em>conditioning dropout</em> during training. On the flip side, that means we can’t use this for post-hoc conditioning – all conditioning signals have to be available during training of the generative model itself. <a href="https://sander.ai/2022/05/26/guidance.html">My previous blog post on guidance</a> covers the differences in more detail.</p>
<p>An animated diagram of a single step of sampling with classifier-free guidance is shown below:</p>
<figure>
<a href="/images/geometry_diagram019.gif"><img src="/images/geometry_diagram019.gif" style="border: 1px dotted #c10;" alt="Animation of a single step of sampling with classifier-free guidance." /></a>
<figcaption>Animation of a single step of sampling with classifier-free guidance.</figcaption>
</figure>
<h2 id="-closing-thoughts"><a name="closing-thoughts"></a> Closing thoughts</h2>
<figure>
<a href="/images/trees_water.jpg"><img src="/images/trees_water.jpg" /></a>
</figure>
<p>What’s surprising about guidance, in my opinion, is how powerful it is in practice, despite its relative simplicity. The modifications to the sampling procedure required to apply guidance are all <strong>linear operations</strong> on vectors in the input space. This is what makes it possible to interpret the procedure geometrically.</p>
<p>How can a set of linear operations affect the outcome of the sampling procedure so profoundly? The key is <strong>iterative refinement</strong>: these simple modifications are applied repeatedly, and crucially, they are interleaved with a very non-linear operation, which is the application of the diffusion model itself, to predict the next update direction. As a result, any linear modification of the update direction has a non-linear effect on the next update direction. Across many sampling steps, the resulting effect is highly non-linear and powerful: small differences in each step accumulate, and result in trajectories with very different endpoints.</p>
<p>I hope the visualisations in this post are a useful complement to <a href="https://sander.ai/2022/05/26/guidance.html">my previous writing on the topic of guidance</a>. Feel free to let me know your thoughts in the comments, or on Twitter/X (<a href="https://twitter.com/sedielem">@sedielem</a>) or Threads (<a href="https://www.threads.net/@sanderdieleman">@sanderdieleman</a>).</p>
<p><em>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</em></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@misc{dieleman2023geometry,
author = {Dieleman, Sander},
title = {The geometry of diffusion guidance},
url = {https://sander.ai/2023/08/28/geometry.html},
year = {2023}
}
</code></pre></div></div>
<h2 id="-acknowledgements"><a name="acknowledgements"></a> Acknowledgements</h2>
<p>Thanks to Bundle for modelling and to kipply for permission to use <a href="https://twitter.com/kipperrii/status/1574557416741474304">this photograph</a>. Thanks to my colleagues at Google DeepMind for various discussions, which continue to shape my thoughts on this topic!</p>
<h2 id="-references"><a name="references"></a> References</h2>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:foundations" role="doc-endnote">
<p>Blum, Hopcroft, Kannan, “<a href="https://www.cs.cornell.edu/jeh/book.pdf">Foundations of Data Science</a>”, Cambridge University Press, 2020. <a href="#fnref:foundations" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:cm" role="doc-endnote">
<p>Song, Dhariwal, Chen, Sutskever, “<a href="https://arxiv.org/abs/2303.01469">Consistency Models</a>”, International Conference on Machine Learning, 2023. <a href="#fnref:cm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ddpm" role="doc-endnote">
<p>Ho, Jain, Abbeel, “<a href="https://proceedings.neurips.cc/paper/2020/hash/4c5bcfec8584af0d967f1ab10179ca4b-Abstract.html">Denoising Diffusion Probabilistic Models</a>”, 2020. <a href="#fnref:ddpm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ddim" role="doc-endnote">
<p>Song, Meng, Ermon, “<a href="https://arxiv.org/abs/2010.02502">Denoising Diffusion Implicit Models</a>”, International Conference on Learning Representations, 2021. <a href="#fnref:ddim" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:dpmpp" role="doc-endnote">
<p>Lu, Zhou, Bao, Chen, Li, Zhu, “<a href="https://arxiv.org/abs/2211.01095">DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models</a>”, arXiv, 2022. <a href="#fnref:dpmpp" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sde" role="doc-endnote">
<p>Song, Sohl-Dickstein, Kingma, Kumar, Ermon and Poole, “<a href="https://arxiv.org/abs/2011.13456">Score-Based Generative Modeling through Stochastic Differential Equations</a>”, International Conference on Learning Representations, 2021. <a href="#fnref:sde" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:sde:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:equilibrium" role="doc-endnote">
<p>Sohl-Dickstein, Weiss, Maheswaranathan and Ganguli, “<a href="https://arxiv.org/abs/1503.03585">Deep Unsupervised Learning using Nonequilibrium Thermodynamics</a>”, International Conference on Machine Learning, 2015. <a href="#fnref:equilibrium" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:beatgans" role="doc-endnote">
<p>Dhariwal, Nichol, “<a href="https://arxiv.org/abs/2105.05233">Diffusion Models Beat GANs on Image Synthesis</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:beatgans" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:universal" role="doc-endnote">
<p>Bansal, Chu, Schwarzschild, Sengupta, Goldblum, Geiping, Goldstein, “<a href="https://arxiv.org/abs/2302.07121">Universal Guidance for Diffusion Models</a>”, Computer Vision and Pattern Recognition, 2023. <a href="#fnref:universal" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:cf" role="doc-endnote">
<p>Ho, Salimans, “<a href="https://openreview.net/forum?id=qw8AKxfYbI">Classifier-Free Diffusion Guidance</a>”, NeurIPS workshop on DGMs and Applications, 2021. <a href="#fnref:cf" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
<h1 id="perspectives-on-diffusion">Perspectives on diffusion</h1>
<p><em>2023-07-20 · <a href="https://sander.ai/2023/07/20/perspectives">https://sander.ai/2023/07/20/perspectives</a></em></p>
<p>Diffusion models appear to come in many shapes and forms. If you pick two random research papers about diffusion and look at how they describe the model class in their respective introductions, chances are they will go about it in very different ways. This can be both frustrating and enlightening: frustrating, because it makes it harder to spot relationships and equivalences across papers and implementations – but also enlightening, because these various perspectives each reveal new connections and are a breeding ground for new ideas. This blog post is an overview of the perspectives on diffusion I’ve found useful.</p>
<p>Last year, I wrote a blog post titled “<a href="https://sander.ai/2022/01/31/diffusion.html">diffusion models are autoencoders</a>”. The title was tongue-in-cheek, but it highlighted a close connection between diffusion models and autoencoders, which I felt had been underappreciated up until then. Since so many more ML practitioners were familiar with autoencoders than with diffusion models, at the time, it seemed like a good idea to try and change that.</p>
<p>Since then, I’ve realised that I could probably write a whole series of blog posts, each highlighting a different perspective or equivalence. Unfortunately I only seem to be able to produce one or two blog posts a year, despite efforts to increase the frequency. So instead, this post will cover all of them at once in considerably less detail – but hopefully enough to pique your curiosity, or to make you see diffusion models in a new light.</p>
<p>This post will probably be most useful to those who already have at least a basic understanding of diffusion models. If you don’t count yourself among this group, or you’d like a refresher, check out my earlier blog posts on the topic:</p>
<ul>
<li><a href="https://sander.ai/2022/01/31/diffusion.html">Diffusion models are autoencoders</a></li>
<li><a href="https://sander.ai/2022/05/26/guidance.html">Guidance: a cheat code for diffusion models</a></li>
<li><a href="https://sander.ai/2023/01/09/diffusion-language.html">Diffusion language models</a></li>
</ul>
<p>Before we start, a <strong>disclaimer</strong>: some of these connections are deliberately quite handwavy. They are intended to build intuition and understanding, and are not supposed to be taken literally, for the most part – this is a blog post, not a peer-reviewed research paper.</p>
<p>That said, I welcome any corrections and thoughts about the ways in which these equivalences don’t quite hold, or could even be misleading. <strong>Feel free to leave a comment, or reach out to me on Twitter (<a href="https://twitter.com/sedielem">@sedielem</a>) or Threads (<a href="https://www.threads.net/@sanderdieleman">@sanderdieleman</a>).</strong> If you have a different perspective that I haven’t covered here, please share it as well.</p>
<p>Alright, here goes (click to scroll to each section):</p>
<ol>
<li><em><a href="#autoencoders">Diffusion models are <strong>autoencoders</strong></a></em></li>
<li><em><a href="#latent">Diffusion models are <strong>deep latent variable models</strong></a></em></li>
<li><em><a href="#score">Diffusion models predict the <strong>score function</strong></a></em></li>
<li><em><a href="#sde">Diffusion models solve <strong>reverse SDEs</strong></a></em></li>
<li><em><a href="#flow">Diffusion models are <strong>flow-based models</strong></a></em></li>
<li><em><a href="#rnn">Diffusion models are <strong>recurrent neural networks</strong></a></em></li>
<li><em><a href="#autoregressive">Diffusion models are <strong>autoregressive models</strong></a></em></li>
<li><em><a href="#expectation">Diffusion models estimate <strong>expectations</strong></a></em></li>
<li><em><a href="#discrete-continuous">Discrete and continuous diffusion models</a></em></li>
<li><em><a href="#alternative">Alternative formulations</a></em></li>
<li><em><a href="#consistency">Consistency</a></em></li>
<li><em><a href="#conventions">Defying conventions</a></em></li>
<li><em><a href="#closing-thoughts">Closing thoughts</a></em></li>
<li><em><a href="#acknowledgements">Acknowledgements</a></em></li>
<li><em><a href="#references">References</a></em></li>
</ol>
<h2 id="-diffusion-models-are-autoencoders"><a name="autoencoders"></a> Diffusion models are autoencoders</h2>
<figure>
<a href="/images/diffuse2.jpg"><img src="/images/diffuse2.jpg" /></a>
</figure>
<p>Denoising autoencoders are neural networks whose input is corrupted by noise, and they are tasked to predict the clean input, i.e. to remove the corruption. Doing well at this task requires learning about the distribution of the clean data. They have been very popular for representation learning, and in the early days of deep learning, they were also used for layer-wise pre-training of deep neural networks<sup id="fnref:bengio" role="doc-noteref"><a href="#fn:bengio" class="footnote" rel="footnote">1</a></sup>.</p>
<p>It turns out that the neural network used in a diffusion model usually solves a very similar problem: given an input example corrupted by noise, it predicts some quantity associated with the data distribution. This can be the corresponding clean input (as in denoising autoencoders), the noise that was added, or something in between (<a href="#conventions">more on that later</a>). All of these are equivalent in some sense when the corruption process is linear, i.e., the noise is additive: we can turn a model that predicts the noise into a model that predicts the clean input, simply by subtracting its prediction from the noisy input. In neural network parlance, we would be adding a residual connection from the input to the output.</p>
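<p>This interchangeability is easy to demonstrate numerically. A minimal numpy sketch for the simplest additive case, where <code>predict_noise</code> is an oracle standing in for a trained neural network (not a real model):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

x0 = rng.normal(size=(4, 8))   # clean data
eps = rng.normal(size=(4, 8))  # additive Gaussian noise
x_noisy = x0 + eps             # linear (additive) corruption

def predict_noise(x):
    # Oracle stand-in for a trained noise-prediction network.
    return eps

# A noise predictor becomes a clean-input predictor by subtracting its
# prediction from the noisy input (a residual connection from input to output):
x0_hat = x_noisy - predict_noise(x_noisy)
```

<p>With scaled corruption processes the same trick works, up to the appropriate rescaling of the prediction.</p>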
<figure style="text-align: center;">
<a href="/images/ae_vs_diffusion_diagram.png"><img src="/images/ae_vs_diffusion_diagram.png" alt="Schematic diagram of a denoising autoencoder (left) and a diffusion model (right)." /></a>
<figcaption>Schematic diagram of a denoising autoencoder (left) and a diffusion model (right).</figcaption>
</figure>
<p>There are a few key differences:</p>
<ul>
<li>
<p>Denoising autoencoders often have some sort of <strong>information bottleneck</strong> somewhere in the middle, to learn a useful representation of the input whose capacity is constrained in some way. The denoising task itself is merely a means to an end, and not what we actually want to use the models for once we’ve trained them. The neural networks used for diffusion models don’t typically have such a bottleneck, as we are more interested in their predictions, rather than the internal representations they construct along the way to be able to make those predictions.</p>
</li>
<li>
<p>Denoising autoencoders can be trained with a variety of types of noise. For example, parts of the input could be masked out (masking noise), or we could add noise drawn from some arbitrary distribution (often Gaussian). For diffusion models, we usually stick with <strong>additive Gaussian noise</strong> because of its helpful mathematical properties, which simplify a lot of operations.</p>
</li>
<li>
<p>Another important difference is that denoising autoencoders are usually trained to deal only with noise of a particular strength. In a diffusion model, we have to be able to make predictions for inputs with a lot of noise, or with very little noise. <strong>The noise level is provided to the neural network as an extra input.</strong></p>
</li>
</ul>
<p>As mentioned, I’ve already discussed this relationship in detail <a href="https://sander.ai/2022/01/31/diffusion.html">in a previous blog post</a>, so check that out if you are keen to explore this connection more thoroughly.</p>
<h2 id="-diffusion-models-are-deep-latent-variable-models"><a name="latent"></a> Diffusion models are deep latent variable models</h2>
<figure>
<a href="/images/stack.jpg"><img src="/images/stack.jpg" /></a>
</figure>
<p>Sohl-Dickstein et al. first suggested using a diffusion process to gradually destroy structure in data, and then constructing a generative model by learning to reverse this process in a 2015 ICML paper<sup id="fnref:noneq" role="doc-noteref"><a href="#fn:noneq" class="footnote" rel="footnote">2</a></sup>. Five years later, Ho et al. built on this to develop <strong>Denoising Diffusion Probabilistic Models</strong> or <strong>DDPMs</strong><sup id="fnref:ddpm" role="doc-noteref"><a href="#fn:ddpm" class="footnote" rel="footnote">3</a></sup>, which formed the blueprint of modern diffusion models along with score-based models (<a href="#score">see below</a>).</p>
<figure style="text-align: center;">
<a href="/images/ddpm.png"><img src="/images/ddpm.png" alt="DDPM graphical model." /></a>
<figcaption>DDPM graphical model.</figcaption>
</figure>
<p>In this formulation, represented by the graphical model above, \(\mathbf{x}_T\) (latent) represents Gaussian noise and \(\mathbf{x}_0\) (observed) represents the data distribution. These random variables are bridged by a finite number of intermediate latent variables \(\mathbf{x}_t\) (typically \(T=1000\)), which form a <strong>Markov chain</strong>, i.e. \(\mathbf{x}_{t-1}\) only depends on \(\mathbf{x}_t\), and not directly on any preceding random variables in the chain.</p>
<p>The parameters of the Markov chain are fit using <strong>variational inference</strong> to reverse a diffusion process, which is itself a Markov chain (in the other direction, represented by \(q(\mathbf{x}_t \mid \mathbf{x}_{t-1})\) in the diagram) that gradually adds Gaussian noise to the data. Concretely, as in Variational Autoencoders (VAEs)<sup id="fnref:vaekingma" role="doc-noteref"><a href="#fn:vaekingma" class="footnote" rel="footnote">4</a></sup><sup id="fnref:vaerezende" role="doc-noteref"><a href="#fn:vaerezende" class="footnote" rel="footnote">5</a></sup>, we can write down an Evidence Lower Bound (ELBO), a bound on the log likelihood, which we can maximise tractably. In fact, this section could just as well have been titled <strong>“diffusion models are deep VAEs”</strong>, but I’ve already used “diffusion models are autoencoders” for a different perspective, so I figured this might have been a bit confusing.</p>
<p>We know \(q(\mathbf{x}_t \mid \mathbf{x}_{t-1})\) is Gaussian by construction, but \(p(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\), which we are trying to fit with our model, need not be! However, as long as each individual step is small enough (i.e. \(T\) is large enough), it turns out that we can parameterise \(p(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\) as if it were Gaussian, and the approximation error will be small enough for this model to still produce good samples. This is kind of surprising when you think about it, as during sampling, any errors may accumulate over \(T\) steps.</p>
<p>Full disclosure: out of all the different perspectives on diffusion in this blog post, this is probably the one I understand least well. Sort of ironic, given how popular it is, but variational inference has always been a little bit mysterious to me. I will stop here, and mostly defer to a few others who have described this perspective in detail (apart from the original DDPM paper, of course):</p>
<ul>
<li><a href="https://angusturner.github.io/generative_models/2021/06/29/diffusion-probabilistic-models-I.html">“Diffusion Models as a kind of VAE” by Angus Turner</a></li>
<li><a href="https://jmtomczak.github.io/blog/10/10_ddgms_lvm_p2.html">Jakub Tomczak’s blog post on DDPMs</a></li>
<li><a href="https://lilianweng.github.io/posts/2021-07-11-diffusion-models/">Lilian Weng’s blog post on diffusion models (connects multiple perspectives)</a></li>
<li><a href="https://blog.alexalemi.com/diffusion.html">Alex Alemi’s blog post about the variational diffusion loss</a></li>
</ul>
<h2 id="-diffusion-models-predict-the-score-function"><a name="score"></a> Diffusion models predict the score function</h2>
<figure>
<a href="/images/darts.jpg"><img src="/images/darts.jpg" /></a>
</figure>
<p>Most likelihood-based generative models parameterise the log-likelihood of an input \(\mathbf{x}\), \(\log p(\mathbf{x} \mid \theta)\), and then fit the model parameters \(\theta\) to maximise it, either approximately (as in VAEs) or exactly (as in flow-based models or autoregressive models). Because log-likelihoods represent probability distributions, and probability distributions have to be normalised, this usually requires some constraints to ensure all possible values for the parameters \(\theta\) yield valid distributions. For example, autoregressive models have causal masking to ensure this, and most flow-based models require invertible neural network architectures.</p>
<p>It turns out there is another way to fit distributions that neatly sidesteps this normalisation requirement, called <strong>score matching</strong><sup id="fnref:scorematching" role="doc-noteref"><a href="#fn:scorematching" class="footnote" rel="footnote">6</a></sup>. It’s based on the observation that the so-called <strong>score function</strong>, \(s_\theta(\mathbf{x}) := \nabla_\mathbf{x} \log p(\mathbf{x} \mid \theta)\), is invariant to the scaling of \(p(\mathbf{x} \mid \theta)\). This is easy to see:</p>
\[\nabla_\mathbf{x} \log \left( \alpha \cdot p(\mathbf{x} \mid \theta) \right) = \nabla_\mathbf{x} \left( \log \alpha + \log p(\mathbf{x} \mid \theta) \right)\]
\[= \nabla_\mathbf{x} \log \alpha + \nabla_\mathbf{x} \log p(\mathbf{x} \mid \theta) = 0 + \nabla_\mathbf{x} \log p(\mathbf{x} \mid \theta) .\]
<p>Any arbitrary scale factor applied to the probability density simply disappears. Therefore, if we have a model that parameterises a score estimate \(\hat{s}_\theta(\mathbf{x})\) directly, we can fit the distribution by minimising the <strong>score matching loss</strong> (instead of maximising the likelihood directly):</p>
\[\mathcal{L}_{SM} := \left( \hat{s}_\theta(\mathbf{x}) - \nabla_\mathbf{x} \log p(\mathbf{x}) \right)^2 .\]
<p>In this form however, this loss function is not practical, because we do not have a good way to compute ground truth scores \(\nabla_\mathbf{x} \log p(\mathbf{x})\) for any data point \(\mathbf{x}\). There are a few tricks that can be applied to sidestep this requirement, and transform this into a loss function that’s easy to compute, including <em>implicit score matching (ISM)</em><sup id="fnref:scorematching:1" role="doc-noteref"><a href="#fn:scorematching" class="footnote" rel="footnote">6</a></sup>, <em>sliced score matching (SSM)</em><sup id="fnref:ssm" role="doc-noteref"><a href="#fn:ssm" class="footnote" rel="footnote">7</a></sup> and <em>denoising score matching (DSM)</em><sup id="fnref:dsm" role="doc-noteref"><a href="#fn:dsm" class="footnote" rel="footnote">8</a></sup>. We’ll take a closer look at this last one:</p>
\[\mathcal{L}_{DSM} := \left( \hat{s}_\theta(\tilde{\mathbf{x}}) - \nabla_\tilde{\mathbf{x}} \log p(\tilde{\mathbf{x}} \mid \mathbf{x}) \right)^2 .\]
<p>Here, \(\tilde{\mathbf{x}}\) is obtained by adding Gaussian noise to \(\mathbf{x}\). This means \(p(\tilde{\mathbf{x}} \mid \mathbf{x})\) is distributed according to a Gaussian distribution \(\mathcal{N}\left(\mathbf{x}, \sigma^2\right)\) and the ground truth conditional score function can be calculated in closed form:</p>
\[\nabla_\tilde{\mathbf{x}} \log p(\tilde{\mathbf{x}} \mid \mathbf{x}) = \nabla_\tilde{\mathbf{x}} \log \left( \frac{1}{\sigma \sqrt{2 \pi}} e^{ -\frac{1}{2} \left( \frac{\tilde{\mathbf{x}} - \mathbf{x}}{\sigma} \right)^2 } \right)\]
\[= \nabla_\tilde{\mathbf{x}} \log \frac{1}{\sigma \sqrt{2 \pi}} - \nabla_\tilde{\mathbf{x}} \left( \frac{1}{2} \left( \frac{\tilde{\mathbf{x}} - \mathbf{x}}{\sigma} \right)^2 \right) = 0 - \frac{1}{2} \cdot 2 \left( \frac{\tilde{\mathbf{x}} - \mathbf{x}}{\sigma} \right) \cdot \frac{1}{\sigma} = \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma^2}.\]
<p>This form has a very intuitive interpretation: it is a scaled version of the Gaussian noise added to \(\mathbf{x}\) to obtain \(\tilde{\mathbf{x}}\). Therefore, <strong>making \(\tilde{\mathbf{x}}\) more likely by following the score (= gradient ascent on the log-likelihood) directly corresponds to removing (some of) the noise</strong>:</p>
\[\tilde{\mathbf{x}} + \eta \cdot \nabla_\tilde{\mathbf{x}} \log p(\tilde{\mathbf{x}} \mid \mathbf{x}) = \tilde{\mathbf{x}} + \frac{\eta}{\sigma^2} \left(\mathbf{x} - \tilde{\mathbf{x}}\right) = \frac{\eta}{\sigma^2} \mathbf{x} + \left(1 - \frac{\eta}{\sigma^2}\right) \tilde{\mathbf{x}} .\]
<p>If we choose the step size \(\eta = \sigma^2\), we recover the clean data \(\mathbf{x}\) in a single step.</p>
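<p>This single-step denoising property is easy to verify numerically, using the closed-form conditional score derived above:</p>

```python
import numpy as np

rng = np.random.default_rng(1)

sigma = 0.5
x = rng.normal(size=3)                    # clean data point
x_tilde = x + sigma * rng.normal(size=3)  # noisy version

# Ground-truth conditional score of N(x, sigma^2) evaluated at x_tilde:
score = (x - x_tilde) / sigma**2

# One gradient ascent step on the log-likelihood with step size eta = sigma^2
# recovers the clean input exactly:
x_denoised = x_tilde + sigma**2 * score
```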
<p>\(\mathcal{L}_{SM}\) and \(\mathcal{L}_{DSM}\) are different loss functions, but the neat thing is that they have <strong>the same minimum</strong> in expectation: \(\mathbb{E}_\mathbf{x} [\mathcal{L}_{SM}] = \mathbb{E}_{\mathbf{x},\tilde{\mathbf{x}}} [\mathcal{L}_{DSM}] + C\), where \(C\) is some constant. Pascal Vincent derived this equivalence back in 2010 (before score matching was cool!) and I strongly recommend reading his tech report about it<sup id="fnref:dsm:1" role="doc-noteref"><a href="#fn:dsm" class="footnote" rel="footnote">8</a></sup> if you want to deepen your understanding.</p>
<p>One important question this approach raises is: how much noise should we add, i.e. <strong>what should \(\sigma\) be?</strong> Picking a particular fixed value for this hyperparameter doesn’t actually work very well in practice. At low noise levels, it is very difficult to estimate the score accurately in low-density regions. At high noise levels, this is less of a problem, because the added noise spreads out the density in all directions – but then the distribution that we’re modelling is significantly distorted by the noise. What works well is to <strong>model the density at many different noise levels</strong>. Once we have such a model, we can <em>anneal</em> \(\sigma\) during sampling, starting with lots of noise and gradually dialing it down. Song & Ermon describe these issues and their elegant solution in detail in their 2019 paper<sup id="fnref:songermon" role="doc-noteref"><a href="#fn:songermon" class="footnote" rel="footnote">9</a></sup>.</p>
<p>This combination of denoising score matching at many different noise levels with gradual annealing of the noise during sampling yields a model that’s essentially equivalent to a DDPM, but the derivation is completely different – no ELBOs in sight! To learn more about this perspective, check out <a href="https://yang-song.net/blog/2021/score/">Yang Song’s excellent blog post on the topic</a>.</p>
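<p>To make the annealing idea concrete, here is a toy 1D sketch of annealed Langevin dynamics, where the data is drawn from \(\mathcal{N}(0, 1)\), so the score of each noise-perturbed marginal \(\mathcal{N}(0, 1 + \sigma^2)\) is known in closed form and stands in for a trained score model. The schedule and step sizes below are illustrative choices, not tuned values:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def score(x, sigma):
    # Closed-form score of the noise-perturbed marginal N(0, 1 + sigma^2);
    # a stand-in for a trained noise-conditional score network.
    return -x / (1.0 + sigma**2)

sigmas = np.geomspace(10.0, 0.1, 10)   # annealed noise levels, high -> low
x = sigmas[0] * rng.normal(size=5000)  # 5000 chains, initialised as pure noise

for sigma in sigmas:
    step = 0.1 * sigma**2              # step size scaled with the noise level
    for _ in range(100):               # Langevin dynamics at this noise level
        z = rng.normal(size=x.shape)
        x = x + 0.5 * step * score(x, sigma) + np.sqrt(step) * z

# x now approximately follows the data distribution N(0, 1).
```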
<h2 id="-diffusion-models-solve-reverse-sdes"><a name="sde"></a> Diffusion models solve reverse SDEs</h2>
<figure>
<a href="/images/backward.jpg"><img src="/images/backward.jpg" /></a>
</figure>
<p>In both of the previous perspectives (deep latent variable models and score matching), we consider a discrete and finite set of steps. These steps correspond to different levels of Gaussian noise, and we can write down a monotonic mapping \(\sigma(t)\) which maps the step index \(t\) to the standard deviation of the noise at that step.</p>
<p>If we let the number of steps go to infinity, it makes sense to replace the discrete index variable with a continuous value \(t\) on an interval \([0, T]\), which can be interpreted as a <em>time</em> variable, i.e. \(\sigma(t)\) now describes the evolution of the standard deviation of the noise over time. In continuous time, we can describe the diffusion process which gradually adds noise to data points \(\mathbf{x}\) with a <strong>stochastic differential equation</strong> (SDE):</p>
\[\mathrm{d} \mathbf{x} = \mathbf{f}(\mathbf{x}, t) \mathrm{d}t + g(t) \mathrm{d} \mathbf{w} .\]
<p>This equation relates an infinitesimal change in \(\mathbf{x}\) with an infinitesimal change in \(t\), and \(\mathrm{d}\mathbf{w}\) represents <em>infinitesimal Gaussian noise</em>, also known as the <em>Wiener process</em>. \(\mathbf{f}\) and \(g\) are called the <em>drift</em> and <em>diffusion</em> coefficients respectively. Particular choices for \(\mathbf{f}\) and \(g\) yield time-continuous versions of the Markov chains used to formulate DDPMs.</p>
<p>SDEs combine differential equations with stochastic random variables, which can seem a bit daunting at first. Luckily we don’t need too much of the advanced SDE machinery that exists to understand how this perspective can be useful for diffusion models. However, there is one very important result that we can make use of. Given an SDE that describes a diffusion process like the one above, <strong>we can write down another SDE that describes the process in the other direction, i.e. reverses time</strong><sup id="fnref:anderson" role="doc-noteref"><a href="#fn:anderson" class="footnote" rel="footnote">10</a></sup>:</p>
\[\mathrm{d}\mathbf{x} = \left(\mathbf{f}(\mathbf{x}, t) - g(t)^2 \nabla_\mathbf{x} \log p_t(\mathbf{x}) \right) \mathrm{d}t + g(t) \mathrm{d} \bar{\mathbf{w}} .\]
<p>This equation also describes a diffusion process. \(\mathrm{d}\bar{\mathbf{w}}\) is the reversed Wiener process, and \(\nabla_\mathbf{x} \log p_t(\mathbf{x})\) is the time-dependent score function. The time dependence comes from the fact that the noise level changes over time.</p>
<p>Explaining why this is the case is beyond the scope of this blog post, but the original paper by Yang Song and colleagues that introduced the SDE-based formalism for diffusion models<sup id="fnref:sde" role="doc-noteref"><a href="#fn:sde" class="footnote" rel="footnote">11</a></sup> is well worth a read.</p>
<p>Concretely, if we have a way to estimate the time-dependent score function, we can simulate the reverse diffusion process, and therefore draw samples from the data distribution starting from noise. So we can once again train a neural network to predict this quantity, and plug it into the reverse SDE to obtain a <em>continuous-time diffusion model</em>.</p>
<p>In practice, simulating this SDE requires discretising the time variable \(t\) again, so you might wonder what the point of all this is. What’s neat is that this discretisation is now something we can decide at sampling-time, and it does not have to be fixed before we train our score prediction model. In other words, we can trade off sample quality for computational cost in a very natural way without changing the model, by choosing the number of sampling steps.</p>
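<p>Here is what the simplest such discretisation (Euler-Maruyama) looks like, on a toy example with \(\mathbf{f} = 0\) and constant \(g\), where the data is \(\mathcal{N}(0, 1)\) so that \(p_t = \mathcal{N}(0, 1 + g^2 t)\) and the time-dependent score is available in closed form, standing in for the trained score network:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

g, T, n_steps = 1.0, 25.0, 1000

def score(x, t):
    # Closed-form time-dependent score for this toy setup;
    # a stand-in for the trained score network.
    return -x / (1.0 + g**2 * t)

dt = T / n_steps
x = np.sqrt(1.0 + g**2 * T) * rng.normal(size=5000)  # x_T: pure noise

# Euler-Maruyama discretisation of the reverse SDE, from t = T down to t = 0:
# x_{t-dt} = x_t + g^2 * score(x_t, t) * dt + g * sqrt(dt) * z
for i in range(n_steps):
    t = T - i * dt
    z = rng.normal(size=x.shape)
    x = x + g**2 * score(x, t) * dt + g * np.sqrt(dt) * z

# x now approximately follows the data distribution N(0, 1).
```

<p>The number of steps <code>n_steps</code> is exactly the sampling-time knob mentioned above: fewer steps are cheaper but less accurate.</p>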
<h2 id="-diffusion-models-are-flow-based-models"><a name="flow"></a> Diffusion models are flow-based models</h2>
<figure>
<a href="/images/waterfall.jpg"><img src="/images/waterfall.jpg" /></a>
</figure>
<p>Remember flow-based models<sup id="fnref:nice" role="doc-noteref"><a href="#fn:nice" class="footnote" rel="footnote">12</a></sup> <sup id="fnref:realnvp" role="doc-noteref"><a href="#fn:realnvp" class="footnote" rel="footnote">13</a></sup>? They aren’t very popular for generative modelling these days, which I think is mainly because they tend to require more parameters than other types of models to achieve the same level of performance. This is due to their limited expressivity: neural networks used in flow-based models are required to be invertible, and the log-determinant of the Jacobian must be easy to compute, which imposes significant constraints on the kinds of computations that are possible.</p>
<p>At least, this is the case for <em>discrete</em> normalising flows. <strong>Continuous normalising flows (CNFs)</strong><sup id="fnref:node" role="doc-noteref"><a href="#fn:node" class="footnote" rel="footnote">14</a></sup> <sup id="fnref:ffjord" role="doc-noteref"><a href="#fn:ffjord" class="footnote" rel="footnote">15</a></sup> also exist, and usually take the form of an <em>ordinary differential equation</em> (ODE) parameterised by a neural network, which describes a deterministic path between samples from the data distribution and corresponding samples from a simple base distribution (e.g. standard Gaussian). CNFs are not affected by the aforementioned neural network architecture constraints, but in their original form, they require backpropagation through an ODE solver to train. Although some tricks exist to do this more efficiently, this probably also presents a barrier to widespread adoption.</p>
<p>Let’s revisit the SDE formulation of diffusion models, which describes a stochastic process mapping samples from a simple base distribution to samples from the data distribution. An interesting question to ask is: <strong>what does the distribution of the intermediate samples \(p_t(\mathbf{x})\) look like, and how does it evolve over time?</strong> This is governed by the so-called <strong>Fokker-Planck equation</strong>. If you want to see what this looks like in practice, check out appendix D.1 of Song et al. (2021)<sup id="fnref:sde:1" role="doc-noteref"><a href="#fn:sde" class="footnote" rel="footnote">11</a></sup>.</p>
<p>Here’s where it gets wild: <strong>there exists an ODE that describes a <em>deterministic</em> process whose time-dependent distributions are exactly the same as those of the <em>stochastic</em> process described by the SDE.</strong> This is called the <strong>probability flow ODE</strong>. What’s more, it has a simple closed form:</p>
\[\mathrm{d} \mathbf{x} = \left( \mathbf{f}(\mathbf{x}, t) - \frac{1}{2}g(t)^2 \nabla_\mathbf{x} \log p_t(\mathbf{x}) \right)\mathrm{d}t .\]
<p>This equation describes both the forward and backward process (just flip the sign to go in the other direction), and note that the time-dependent score function \(\nabla_\mathbf{x} \log p_t(\mathbf{x})\) once again features. To prove this, you can write down the Fokker-Planck equations for both the SDE and the probability flow ODE, and do some algebra to show that they are the same, and hence must have the same solution \(p_t(\mathbf{x})\).</p>
<p>Note that this ODE does not describe the <em>same</em> process as the SDE: that would be impossible, because a deterministic differential equation cannot describe a stochastic process. Instead, it describes a <em>different</em> process with the unique property that the distributions \(p_t(\mathbf{x})\) are the same for both processes. Check out the probability flow ODE section in <a href="https://yang-song.net/blog/2021/score/#probability-flow-ode">Yang Song’s blog post</a> for a great diagram comparing both processes.</p>
<p>The implications of this are profound: <strong>there is now a <em>bijective mapping</em> between particular samples from the simple base distribution, and samples from the data distribution</strong>. We have a sampling process where all the randomness is contained in the initial base distribution sample – once that’s been sampled, going from there to a data sample is completely deterministic. It also means that we can map data points to their corresponding latent representations by simulating the ODE forward, manipulating them, and then mapping them back to the data space by simulating the ODE backward.</p>
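<p>We can illustrate this bijectivity in the same toy setup as before (data \(\mathcal{N}(0,1)\), \(\mathbf{f} = 0\), \(g = 1\), closed-form score standing in for the network): integrating the probability flow ODE forward maps data points to latents, and integrating it backward recovers them. Simple Euler integration with an illustrative step count:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def score(x, t):
    # Closed-form time-dependent score for the toy setup (stand-in for a model).
    return -x / (1.0 + t)

def ode_drift(x, t):
    # Probability flow ODE: dx/dt = f(x, t) - 0.5 * g(t)^2 * score(x, t),
    # with f = 0 and g = 1 here.
    return -0.5 * score(x, t)

T, n_steps = 25.0, 4000
dt = T / n_steps
ts = np.arange(n_steps) * dt

x0 = rng.normal(size=5)
x = x0.copy()
for t in ts:          # data -> latent: integrate forward in time
    x = x + ode_drift(x, t) * dt
for t in ts[::-1]:    # latent -> data: integrate backward in time
    x = x - ode_drift(x, t) * dt

# Up to discretisation error, the deterministic map is invertible: x ~= x0.
```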
<p>The model described by the probability flow ODE <em>is</em> a continuous normalising flow, but it’s one that we managed to train without having to backpropagate through an ODE, rendering the approach much more scalable.</p>
<p>The fact that all this is possible, without even changing anything about how the model is trained, still feels like magic to me. We can plug our score predictor into the reverse SDE from the previous section, or the ODE from this one, and get out two different generative models that model the same distribution in different ways. How cool is that?</p>
<p>As a bonus, the probability flow ODE also enables <strong>likelihood computation</strong> for diffusion models (see appendix D.2 of Song et al. (2021)<sup id="fnref:sde:2" role="doc-noteref"><a href="#fn:sde" class="footnote" rel="footnote">11</a></sup>). This also requires solving the ODE, so it’s roughly as expensive as sampling.</p>
<p>For all of the reasons above, the probability flow ODE paradigm has proven quite popular recently. Among other examples, it is used by Karras et al.<sup id="fnref:elucidating" role="doc-noteref"><a href="#fn:elucidating" class="footnote" rel="footnote">16</a></sup> as a basis for their work investigating various diffusion modelling design choices, and my colleagues and I recently used it for our work on diffusion language models<sup id="fnref:cdcd" role="doc-noteref"><a href="#fn:cdcd" class="footnote" rel="footnote">17</a></sup>. It has also been generalised and extended beyond diffusion processes, to enable learning a mapping between any pair of distributions, e.g. in the form of Flow Matching<sup id="fnref:flowmatching" role="doc-noteref"><a href="#fn:flowmatching" class="footnote" rel="footnote">18</a></sup>, Rectified Flows<sup id="fnref:rectifiedflow" role="doc-noteref"><a href="#fn:rectifiedflow" class="footnote" rel="footnote">19</a></sup> and Stochastic Interpolants<sup id="fnref:stochasticinterpolants" role="doc-noteref"><a href="#fn:stochasticinterpolants" class="footnote" rel="footnote">20</a></sup>.</p>
<p><em>Side note:</em> another way to obtain a deterministic sampling process for diffusion models is given by DDIM<sup id="fnref:ddim" role="doc-noteref"><a href="#fn:ddim" class="footnote" rel="footnote">21</a></sup>, which is based on the deep latent variable model perspective.</p>
<h2 id="-diffusion-models-are-recurrent-neural-networks-rnns"><a name="rnn"></a> Diffusion models are recurrent neural networks (RNNs)</h2>
<figure>
<a href="/images/spiral_staircase.jpg"><img src="/images/spiral_staircase.jpg" /></a>
</figure>
<p>Sampling from a diffusion model involves making repeated predictions with a neural network and using those predictions to update a <em>canvas</em>, which starts out filled with random noise. If we consider the full computational graph of this process, it starts to look a lot like a recurrent neural network (RNN). In RNNs, there is a <em>hidden state</em> which repeatedly gets updated by passing it through a recurrent <em>cell</em>, which consists of one or more nonlinear parameterised operations (e.g. the gating mechanisms of LSTMs<sup id="fnref:lstm" role="doc-noteref"><a href="#fn:lstm" class="footnote" rel="footnote">22</a></sup>). Here, the hidden state is the canvas, so it lives in the input space, and the cell is formed by the <em>denoiser</em> neural network that we’ve trained for our diffusion model.</p>
<figure style="text-align: center;">
<a href="/images/sampling_loop.png"><img src="/images/sampling_loop.png" alt="Schematic diagram of the unrolled diffusion sampling loop." /></a>
<figcaption>Schematic diagram of the unrolled diffusion sampling loop.</figcaption>
</figure>
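<p>To make the analogy concrete, here is a minimal sketch of the unrolled sampling loop for a toy example of my own devising (not from any paper): 1-D standard normal data under a variance-exploding process with \(\sigma(t) = t\), where the exact score \(-x / (1 + t^2)\) is known in closed form and stands in for the neural network. The same “cell” is applied at every step, just like in an RNN:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def score(x, t):
    # Exact score of p_t for toy 1-D standard normal data under a
    # variance-exploding process x_t = x_0 + t * eps: p_t = N(0, 1 + t^2).
    # In a real diffusion model, a neural network would estimate this.
    return -x / (1.0 + t * t)

def sample(x_T, ts):
    # The unrolled sampling loop: the same "cell" (the score model) is
    # applied at every step, like an RNN whose hidden state is the canvas.
    # Euler integration of the probability flow ODE dx/dt = -t * score(x, t).
    x = x_T
    for t_cur, t_next in zip(ts[:-1], ts[1:]):  # ts runs backward: T -> 0
        x = x + (t_next - t_cur) * (-t_cur * score(x, t_cur))
    return x

T = 80.0
ts = np.linspace(T, 0.0, 4000)
x_T = rng.normal(0.0, np.sqrt(1.0 + T * T), size=20_000)  # pure noise
x_0 = sample(x_T, ts)
print(x_0.std())  # close to 1: samples from the data distribution N(0, 1)
```

<p>Swapping the closed-form score for a trained network recovers the actual sampler; the recurrent structure of the computation is unchanged.</p>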
<p>RNNs are usually trained with backpropagation through time (BPTT), with gradients propagated through the recurrence. The number of recurrent steps to backpropagate through is often limited to some maximum number to reduce the computational cost, which is referred to as truncated BPTT. Diffusion models are also trained by backpropagation, but only through one step at a time. In some sense, <strong>diffusion models present a way to train deep recurrent neural networks without backpropagating through the recurrence at all</strong>, yielding a much more scalable training procedure.</p>
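<p>Here is a toy numerical illustration of that single-step training objective (my own sketch, assuming 1-D standard normal data and a linear denoiser fitted in closed form, standing in for a neural network trained by gradient descent). One corruption step and one regression suffice to recover the ideal denoiser, with no unrolling involved:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# training touches one denoising step at a time: corrupt clean data,
# then regress the clean data back -- no backprop through the recurrence
sigma = 3.0                        # a single (arbitrary) noise level
x0 = rng.normal(size=100_000)      # toy 1-D data: standard normal
x_t = x0 + sigma * rng.normal(size=100_000)

# closed-form least-squares fit of a linear denoiser D(x_t) = w * x_t
w = (x_t * x0).sum() / (x_t * x_t).sum()

# this matches the ideal denoiser E[x0 | x_t] = x_t / (1 + sigma^2)
print(w, 1.0 / (1.0 + sigma**2))  # both close to 0.1
```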
<p>RNNs are usually deterministic, so this analogy makes the most sense for the deterministic process based on the probability flow ODE described in the previous section – though injecting noise into the hidden state of RNNs as a means of regularisation is not unheard of, so I think the analogy also works for the stochastic process.</p>
<p>The total depth of this computation graph in terms of the number of nonlinear layers is given by the number of layers in our neural network, multiplied by the number of sampling steps. We can look at the unrolled recurrence as a very deep neural network in its own right, with potentially thousands of layers. This is a lot of depth, but it stands to reason that a challenging task like generative modelling of real-world data requires such deep computation graphs.</p>
<p>We can also consider what happens if we do not use the same neural network at each diffusion sampling step, but potentially different ones for different ranges of noise levels. These networks can be trained separately and independently, and can even have different architectures. This means we are effectively <strong>“untying the weights”</strong> in our very deep network, turning it from an RNN into a plain old deep neural network, but we are still able to avoid having to backpropagate through all of it in one go. Stable Diffusion XL<sup id="fnref:sdxl" role="doc-noteref"><a href="#fn:sdxl" class="footnote" rel="footnote">23</a></sup> uses this approach to great effect for its “Refiner” model, so I think it might start to catch on.</p>
<p>When I started my PhD in 2010, training neural networks with more than two hidden layers was a chore: backprop didn’t work well out of the box, so we used unsupervised layer-wise pre-training<sup id="fnref:bengio:1" role="doc-noteref"><a href="#fn:bengio" class="footnote" rel="footnote">1</a></sup> <sup id="fnref:dbns" role="doc-noteref"><a href="#fn:dbns" class="footnote" rel="footnote">24</a></sup> to find a good initialisation which would make backpropagation possible. Nowadays, even hundreds of nonlinear layers do not form an obstacle anymore. Therefore <strong>it’s not inconceivable that several years from now, training networks with tens of thousands of layers by backprop will be within reach</strong>. At that point, the “divide and conquer” approach that diffusion models offer might lose its lustre, and perhaps we’ll all go back to training deep variational autoencoders! (Note that the same “divide and conquer” perspective equally applies to autoregressive models, so they would become obsolete as well, in that case.)</p>
<p>One question this perspective raises is whether diffusion models might actually work better if we backpropagated through the sampling procedure for two or more steps. This approach isn’t popular, which probably indicates that it isn’t cost-effective in practice. There is one important exception (sort of): models which use <em>self-conditioning</em><sup id="fnref:selfcond" role="doc-noteref"><a href="#fn:selfcond" class="footnote" rel="footnote">25</a></sup>, such as Recurrent Interface Networks (RINs)<sup id="fnref:rin" role="doc-noteref"><a href="#fn:rin" class="footnote" rel="footnote">26</a></sup>, pass some form of state between the diffusion sampling steps, in addition to the updated canvas. To enable the model to learn to make use of this state, an approximation of it is made available during training by running an additional forward pass. There is no additional backward pass though, so this doesn’t really count as two steps of BPTT – more like 1.5 steps.</p>
<h2 id="-diffusion-models-are-autoregressive-models"><a name="autoregressive"></a> Diffusion models are autoregressive models</h2>
<figure>
<a href="/images/arguidance.jpg"><img src="/images/arguidance.jpg" /></a>
</figure>
<p>For diffusion models of natural images, <strong>the sampling process tends to produce large-scale structure first, and then iteratively adds more and more fine-grained details</strong>. Indeed, there seems to be almost a direct correspondence between noise levels and feature scales, which I discussed in more detail in Section 5 of <a href="https://sander.ai/2022/01/31/diffusion.html#scale">a previous blog post</a>.</p>
<p>But why is this the case? To understand this, it helps to think in terms of spatial frequencies. Large-scale features in images correspond to low spatial frequencies, whereas fine-grained details correspond to high frequencies. We can decompose images into their spatial frequency components using the 2D Fourier transform (or some variant of it). This is often the first step in image compression algorithms, because the human visual system is known to be much less sensitive to high frequencies, and this can be exploited by compressing them more aggressively than low frequencies.</p>
<figure style="text-align: center;">
<a href="/images/dct.png"><img src="/images/dct.png" alt="Visualisation of the spatial frequency components of the 8x8 discrete cosine transform, used in e.g. JPEG." /></a>
<figcaption>Visualisation of the spatial frequency components of the 8x8 discrete cosine transform, used in e.g. JPEG.</figcaption>
</figure>
<p>Natural images, along with many other natural signals, exhibit an interesting phenomenon in the frequency domain: the magnitude of different frequency components tends to drop off proportionally to the inverse of the frequency<sup id="fnref:imagestats" role="doc-noteref"><a href="#fn:imagestats" class="footnote" rel="footnote">27</a></sup>: \(S(f) \propto 1/f\) (or the inverse of the square of the frequency, if you’re looking at power spectra instead of magnitude spectra).</p>
<p>Gaussian noise, on the other hand, has a flat spectrum: in expectation, all frequencies have the same magnitude. Since the Fourier transform is a linear operation, adding Gaussian noise to a natural image yields a new image whose spectrum is the sum of the spectrum of the original image, and the flat spectrum of the noise. In the log-domain, this superposition of the two spectra looks like a hinge, which shows how the addition of noise obscures any structure present in higher spatial frequencies (see figure below). The larger the standard deviation of this noise, the more spatial frequencies will be affected.</p>
<figure style="text-align: center;">
<a href="/images/image_spectra.png"><img src="/images/image_spectra.png" alt="Magnitude spectra of natural images, Gaussian noise, and noisy images." /></a>
<figcaption>Magnitude spectra of natural images, Gaussian noise, and noisy images.</figcaption>
</figure>
<p>Since diffusion models are constructed by progressively adding more noise to input examples, we can say that this process increasingly drowns out lower and lower frequency content, until all structure is erased (for natural images, at least). When sampling from the model, we go in the opposite direction and effectively add structure at higher and higher spatial frequencies. This basically looks like <strong>autoregression, but in frequency space</strong>! Rissanen et al. (2023) discuss this observation in Section 2.2 of their paper<sup id="fnref:heat" role="doc-noteref"><a href="#fn:heat" class="footnote" rel="footnote">28</a></sup> on generative modelling with inverse heat dissipation (as an alternative to Gaussian diffusion), though they do not make the connection to autoregressive models. I added that bit, so this section could have a provocative title.</p>
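<p>The spectral argument above is easy to verify numerically. Below is a small sketch (my own, using a synthetic 1-D signal shaped to a \(1/f\) magnitude spectrum as a stand-in for a natural image): the signal dominates at low frequencies, while the flat-spectrum Gaussian noise drowns out the high frequencies first:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4096

# toy 1-D "natural" signal: white noise shaped to a 1/f magnitude spectrum
freqs = np.fft.rfftfreq(n)
shaping = np.zeros_like(freqs)
shaping[1:] = 1.0 / freqs[1:]
signal = np.fft.irfft(np.fft.rfft(rng.normal(size=n)) * shaping)

noise = rng.normal(size=n) * signal.std()  # Gaussian noise, matched scale

spec_signal = np.abs(np.fft.rfft(signal))
spec_noise = np.abs(np.fft.rfft(noise))

# the signal's spectrum falls off with frequency, the noise spectrum is
# flat (in expectation), so their sum looks like a hinge: signal dominates
# the low frequencies, noise dominates the high frequencies
lo, hi = slice(1, 20), slice(n // 4, n // 2)
print(spec_signal[lo].mean() > spec_noise[lo].mean())  # True
print(spec_noise[hi].mean() > spec_signal[hi].mean())  # True
```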
<p>An important caveat is that this interpretation relies on the frequency characteristics of natural signals, so for applications of diffusion models in other domains (e.g. language modelling, see Section 2 of <a href="https://sander.ai/2023/01/09/diffusion-language.html#match">my blog post on diffusion language models</a>), the analogy may not make sense.</p>
<h2 id="-diffusion-models-estimate-expectations"><a name="expectation"></a> Diffusion models estimate expectations</h2>
<figure>
<a href="/images/measuring_tape.jpg"><img src="/images/measuring_tape.jpg" /></a>
</figure>
<p>Consider the transition density \(p(\mathbf{x}_t \mid \mathbf{x}_0)\), which describes the distribution of the noisy data example \(\mathbf{x}_t\) at time \(t\), conditioned on the original clean input \(\mathbf{x}_0\) it was derived from (by adding noise). Based on samples from this distribution, the neural network used in a diffusion model is tasked to predict the expectation \(\mathbb{E}[\mathbf{x}_0 \mid \mathbf{x}_t]\) (or some linear time-dependent function of it). This may seem a tad obvious, but I wanted to highlight some of the implications.</p>
<p>First, it provides another motivation for why the mean squared error (MSE) is the right loss function to use for training diffusion models. During training, the expectation \(\mathbb{E}[\mathbf{x}_0 \mid \mathbf{x}_t]\) is not known, so instead we supervise the model using \(\mathbf{x}_0\) itself. Because <strong>the minimiser of the MSE loss is precisely the expectation</strong>, we end up recovering (an approximation of) \(\mathbb{E}[\mathbf{x}_0 \mid \mathbf{x}_t]\), even though we don’t know this quantity a priori. This is a bit different from typical supervised learning problems, where the ideal outcome would be for the model to predict exactly the targets used to supervise it (barring any label errors). Here, we purposely do not want that. More generally, the notion of being able to estimate conditional expectations, even though we only provide supervision through samples, is very powerful.</p>
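<p>This property is easy to demonstrate numerically. In the sketch below (a toy of my own, using a deliberately skewed distribution to stand in for \(p(\mathbf{x}_0 \mid \mathbf{x}_t)\)), a brute-force search over constant predictions under the MSE loss recovers the mean of the samples, rather than, say, the median or the mode:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# draws standing in for x0 | x_t: we only ever observe samples, never
# the expectation itself; the distribution is deliberately skewed
samples = rng.exponential(scale=2.0, size=10_000)

# brute-force search for the constant prediction minimising the MSE
candidates = np.linspace(0.0, 6.0, 601)
losses = ((candidates[:, None] - samples[None, :]) ** 2).mean(axis=1)
best = candidates[np.argmin(losses)]

# the minimiser is the mean (~2.0), not the median (~1.39) or mode (0):
# supervising with samples under MSE recovers the expectation
print(best, samples.mean())
```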
<p>Second, it explains why distillation<sup id="fnref:distillation" role="doc-noteref"><a href="#fn:distillation" class="footnote" rel="footnote">29</a></sup> of diffusion models<sup id="fnref:progressive" role="doc-noteref"><a href="#fn:progressive" class="footnote" rel="footnote">30</a></sup> <sup id="fnref:guided" role="doc-noteref"><a href="#fn:guided" class="footnote" rel="footnote">31</a></sup> <sup id="fnref:tract" role="doc-noteref"><a href="#fn:tract" class="footnote" rel="footnote">32</a></sup> is such a compelling proposition: in this setting, we are able to supervise a diffusion model <em>directly</em> with an approximation of the target expectation \(\mathbb{E}[\mathbf{x}_0 \mid \mathbf{x}_t]\) that we want it to predict, because that is what the teacher model already provides. As a result, the variance of the training loss will be much lower than if we had trained the model from scratch, and convergence will be much faster. Of course, this is only useful if you already have a trained model on hand to use as a teacher.</p>
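<p>The variance reduction is easy to see in a toy setting where the posterior expectation is tractable (a sketch of my own, assuming 1-D standard normal data with variance-exploding corruption, so that \(\mathbb{E}[\mathbf{x}_0 \mid \mathbf{x}_t] = \mathbf{x}_t / (1 + \sigma^2)\) in closed form):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# 1-D standard normal data, corruption x_t = x0 + sigma * eps: the
# posterior x0 | x_t is Gaussian with known mean and variance
sigma, x_t = 1.0, 0.5
posterior_mean = x_t / (1.0 + sigma**2)
posterior_std = (sigma**2 / (1.0 + sigma**2)) ** 0.5
x0_draws = rng.normal(posterior_mean, posterior_std, size=100_000)

pred = 0.3  # some fixed model prediction for this particular x_t

loss_from_scratch = (pred - x0_draws) ** 2     # supervised with samples
loss_distilled = (pred - posterior_mean) ** 2  # supervised with the
                                               # teacher's expectation

print(loss_from_scratch.var())  # strictly positive: a noisy training signal
print(loss_distilled)           # a single deterministic value: no variance
```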
<h2 id="-discrete-and-continuous-diffusion-models"><a name="discrete-continuous"></a> Discrete and continuous diffusion models</h2>
<figure>
<a href="/images/discrete.jpg"><img src="/images/discrete.jpg" /></a>
</figure>
<p>So far, we have covered several perspectives that consider a finite set of discrete noise levels, and several perspectives that use a notion of continuous time, combined with a mapping function \(\sigma(t)\) to map time steps to the corresponding standard deviation of the noise. These are typically referred to as <strong>discrete-time</strong> and <strong>continuous-time</strong> respectively. One thing that’s quite neat is that this is mostly a matter of interpretation: models trained within a discrete-time perspective can usually be repurposed quite easily to work in the continuous-time setting<sup id="fnref:elucidating:1" role="doc-noteref"><a href="#fn:elucidating" class="footnote" rel="footnote">16</a></sup>, and vice versa.</p>
<p>Another way in which diffusion models can be discrete or continuous, is <strong>with respect to the input space</strong>. In the literature, I’ve found that it is sometimes unclear whether “continuous” or “discrete” are meant to be with respect to time, or with respect to the input. This is especially important because some perspectives only really make sense for continuous input, as they rely on gradients with respect to the input (i.e. all perspectives based on the score function).</p>
<p>All four combinations of discreteness/continuity exist:</p>
<ul>
<li><strong>discrete time, continuous input</strong>: the original deep latent variable model perspective (DDPMs), as well as the score-based perspective;</li>
<li><strong>continuous time, continuous input</strong>: SDE- and ODE-based perspectives;</li>
<li><strong>discrete time, discrete input</strong>: D3PM<sup id="fnref:d3pm" role="doc-noteref"><a href="#fn:d3pm" class="footnote" rel="footnote">33</a></sup>, MaskGIT<sup id="fnref:maskgit" role="doc-noteref"><a href="#fn:maskgit" class="footnote" rel="footnote">34</a></sup>, Mask-predict<sup id="fnref:maskpredict" role="doc-noteref"><a href="#fn:maskpredict" class="footnote" rel="footnote">35</a></sup>, ARDM<sup id="fnref:ardm" role="doc-noteref"><a href="#fn:ardm" class="footnote" rel="footnote">36</a></sup>, Multinomial diffusion<sup id="fnref:multinomial" role="doc-noteref"><a href="#fn:multinomial" class="footnote" rel="footnote">37</a></sup> and SUNDAE<sup id="fnref:sundae" role="doc-noteref"><a href="#fn:sundae" class="footnote" rel="footnote">38</a></sup> are all methods that use iterative refinement on discrete inputs – whether all of these should be considered diffusion models isn’t entirely clear (it depends on who you ask);</li>
<li><strong>continuous time, discrete input</strong>: Continuous Time Markov Chains (CTMCs)<sup id="fnref:ctmc" role="doc-noteref"><a href="#fn:ctmc" class="footnote" rel="footnote">39</a></sup>, Score-based Continuous-time Discrete Diffusion Models<sup id="fnref:discretescore" role="doc-noteref"><a href="#fn:discretescore" class="footnote" rel="footnote">40</a></sup> and Blackout Diffusion<sup id="fnref:blackout" role="doc-noteref"><a href="#fn:blackout" class="footnote" rel="footnote">41</a></sup> all pair discrete input with continuous time – this setting is also often handled by embedding discrete data in Euclidean space, and then performing input-continuous diffusion in that space, as in e.g. Analog Bits<sup id="fnref:selfcond:1" role="doc-noteref"><a href="#fn:selfcond" class="footnote" rel="footnote">25</a></sup>, Self-conditioned Embedding Diffusion<sup id="fnref:sed" role="doc-noteref"><a href="#fn:sed" class="footnote" rel="footnote">42</a></sup> and CDCD<sup id="fnref:cdcd:1" role="doc-noteref"><a href="#fn:cdcd" class="footnote" rel="footnote">17</a></sup>.</li>
</ul>
<h2 id="-alternative-formulations"><a name="alternative"></a> Alternative formulations</h2>
<figure>
<a href="/images/adhoc.jpg"><img src="/images/adhoc.jpg" /></a>
</figure>
<p>Recently, a few papers have proposed new derivations of this class of models from first principles with the benefit of hindsight, avoiding concepts such as differential equations, ELBOs or score matching altogether. These works provide yet another perspective on diffusion models, which may be more accessible because it requires less background knowledge.</p>
<p><strong>Inversion by Direct Iteration (InDI)</strong><sup id="fnref:indi" role="doc-noteref"><a href="#fn:indi" class="footnote" rel="footnote">43</a></sup> is a formulation rooted in image restoration, intended to harness iterative refinement to improve perceptual quality. No assumptions are made about the nature of the image degradations, and models are trained on paired low-quality and high-quality examples. <strong>Iterative \(\alpha\)-(de)blending</strong><sup id="fnref:deblend" role="doc-noteref"><a href="#fn:deblend" class="footnote" rel="footnote">44</a></sup> uses linear interpolation between samples from two different distributions as a starting point to obtain a deterministic mapping between the distributions. Both of these methods are also closely related to Flow Matching<sup id="fnref:flowmatching:1" role="doc-noteref"><a href="#fn:flowmatching" class="footnote" rel="footnote">18</a></sup>, Rectified Flow<sup id="fnref:rectifiedflow:1" role="doc-noteref"><a href="#fn:rectifiedflow" class="footnote" rel="footnote">19</a></sup> and Stochastic Interpolants<sup id="fnref:stochasticinterpolants:1" role="doc-noteref"><a href="#fn:stochasticinterpolants" class="footnote" rel="footnote">20</a></sup> discussed earlier.</p>
<h2 id="-consistency"><a name="consistency"></a> Consistency</h2>
<figure>
<a href="/images/consistency.jpg"><img src="/images/consistency.jpg" /></a>
</figure>
<p>A few different notions of “consistency” in diffusion models have arisen in the literature recently:</p>
<ul>
<li>
<p><strong>Consistency models (CM)</strong><sup id="fnref:cm" role="doc-noteref"><a href="#fn:cm" class="footnote" rel="footnote">45</a></sup> are trained to map points on any trajectory of the probability flow ODE to the trajectory’s origin (i.e. the clean data point), enabling sampling in a single step. This is done indirectly by taking pairs of points on a particular trajectory and ensuring that the model output is the same for both (hence “consistency”). There is a distillation variant which starts from an existing diffusion model, but it is also possible to train a consistency model from scratch.</p>
</li>
<li>
<p><strong>Consistent diffusion models (CDM)</strong><sup id="fnref:cdm" role="doc-noteref"><a href="#fn:cdm" class="footnote" rel="footnote">46</a></sup> are trained using a regularisation term that explicitly encourages consistency, which they define to mean that the prediction of the denoiser should correspond to the conditional expectation \(\mathbb{E}[\mathbf{x}_0 \mid \mathbf{x}_t]\) (see <a href="#expectation">earlier</a>).</p>
</li>
<li>
<p><strong>FP-Diffusion</strong><sup id="fnref:fpdiffusion" role="doc-noteref"><a href="#fn:fpdiffusion" class="footnote" rel="footnote">47</a></sup> takes the Fokker-Planck equation describing the evolution across time of \(p_t(\mathbf{x})\), and introduces an explicit regularisation term to ensure that it holds.</p>
</li>
</ul>
<p>Each of these properties would trivially hold for an ideal diffusion model (i.e. fully converged, in the limit of infinite capacity). However, real diffusion models are approximate, and so they tend not to hold in practice, which is why it makes sense to add mechanisms to explicitly enforce them.</p>
<p>The main reason for including this section here is that I wanted to highlight a recent paper by Lai et al. (2023)<sup id="fnref:equivalenceconsistency" role="doc-noteref"><a href="#fn:equivalenceconsistency" class="footnote" rel="footnote">48</a></sup> that shows that these three different notions of consistency are essentially different perspectives on the same thing. I thought this was a very elegant result, and it definitely suits the theme of this blog post!</p>
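<p>To build some intuition for the first of these notions, consider a toy case where the ideal consistency function is available in closed form (my own example, assuming variance-exploding diffusion of 1-D standard normal data, for which the probability flow ODE trajectories are \(x(t) = x(0)\sqrt{1 + t^2}\)). A consistency model trains a network to approximate this map; here we can simply write it down:</p>

```python
import numpy as np

# Toy VE diffusion of 1-D standard normal data: the probability flow ODE
# dx/dt = t * x / (1 + t^2) has trajectories x(t) = x(0) * sqrt(1 + t^2),
# so the ideal map from any point on a trajectory back to its origin is:
def consistency_fn(x, t):
    return x / np.sqrt(1.0 + t * t)

x_origin = 0.73  # the clean data point at t = 0
for t in [0.1, 1.0, 10.0, 80.0]:
    x_t = x_origin * np.sqrt(1.0 + t * t)    # a point on the trajectory
    assert np.isclose(consistency_fn(x_t, t), x_origin)

# the defining property: outputs agree for any two points on one trajectory
t_a, t_b = 0.5, 7.0
x_a = x_origin * np.sqrt(1.0 + t_a**2)
x_b = x_origin * np.sqrt(1.0 + t_b**2)
assert np.isclose(consistency_fn(x_a, t_a), consistency_fn(x_b, t_b))
```

<p>Consistency training enforces exactly this pairwise agreement, without ever having access to the closed form.</p>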
<h2 id="-defying-conventions"><a name="conventions"></a> Defying conventions</h2>
<figure>
<a href="/images/split.jpg"><img src="/images/split.jpg" /></a>
</figure>
<p>Apart from all these different perspectives on a conceptual level, the diffusion literature is also, in my experience, particularly prone to reinventing notation and defying conventions. Sometimes, even two different descriptions of the <em>same</em> conceptual perspective look nothing alike. This doesn’t help accessibility and increases the barrier to entry. (I’m not blaming anyone for this, to be clear – in fact, I suspect I might be contributing to the problem with this blog post. Sorry about that.)</p>
<p>There are also a few other seemingly innocuous details and parameterisation choices that can have profound implications. Here are three things to watch out for:</p>
<ul>
<li>
<p>By and large, people use <strong>variance-preserving</strong> (VP) diffusion processes, where in addition to adding noise at each step, the current canvas is rescaled to preserve the overall variance. However, the <strong>variance-exploding</strong> (VE) formulation, where no rescaling happens and the variance of the added noise increases towards infinity, has also gained some followers. Most notably it is used by Karras et al. (2022)<sup id="fnref:elucidating:2" role="doc-noteref"><a href="#fn:elucidating" class="footnote" rel="footnote">16</a></sup>. Some results that hold for VP diffusion might not hold for VE diffusion or vice versa (without making the requisite changes), and this might not be mentioned explicitly. If you’re reading a diffusion paper, make sure you are aware of which formulation is used, and whether any assumptions are being made about it.</p>
</li>
<li>
<p>Sometimes, the neural network used in a diffusion model is parameterised to <strong>predict the (standardised) noise</strong> added to the input, or the <strong>score function</strong>; sometimes it <strong>predicts the clean input</strong> instead, or even a <strong>time-dependent combination of the two</strong> (as in e.g. \(\mathbf{v}\)-prediction<sup id="fnref:progressive:1" role="doc-noteref"><a href="#fn:progressive" class="footnote" rel="footnote">30</a></sup>). All of these targets are equivalent in the sense that they are time-dependent linear functions of each other and the noisy input \(\mathbf{x}_t\). But it is important to understand how this interacts with the <strong>relative weighting of loss contributions for different time steps</strong> during training, which can significantly affect model performance. Out of the box, predicting the standardised noise seems to be a great choice for image data. When modelling certain other quantities (e.g. latents in latent diffusion), people have found predicting the clean input to work better. This is primarily because it implies a different weighting of noise levels, and hence feature scales.</p>
</li>
<li>
<p>It is generally understood that the standard deviation of the noise added by the corruption process increases with time, i.e. <strong>entropy increases over time</strong>, as it tends to do in our universe. Therefore, \(\mathbf{x}_0\) corresponds to clean data, and \(\mathbf{x}_T\) (for some large enough \(T\)) corresponds to pure noise. Some works (e.g. Flow Matching<sup id="fnref:flowmatching:2" role="doc-noteref"><a href="#fn:flowmatching" class="footnote" rel="footnote">18</a></sup>) invert this convention, which can be very confusing if you don’t notice it straight away.</p>
</li>
</ul>
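<p>The equivalence of these prediction targets is worth spelling out in code. Here is a quick sketch (assuming the common trigonometric variance-preserving parameterisation \(\alpha_t = \cos t\), \(\sigma_t = \sin t\), so that \(\alpha_t^2 + \sigma_t^2 = 1\)): given \(\mathbf{x}_t\), each target can be recovered exactly from any other via a time-dependent linear map:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# variance-preserving setup: x_t = alpha * x0 + sigma * eps
t = 0.7                               # arbitrary time step
alpha, sigma = np.cos(t), np.sin(t)   # so alpha^2 + sigma^2 = 1
x0 = rng.normal(size=1000)
eps = rng.normal(size=1000)
x_t = alpha * x0 + sigma * eps

# v-prediction target: a time-dependent combination of eps and x0
v = alpha * eps - sigma * x0

# each target is a time-dependent linear function of any other and x_t
x0_from_eps = (x_t - sigma * eps) / alpha
x0_from_v = alpha * x_t - sigma * v
eps_from_v = sigma * x_t + alpha * v

assert np.allclose(x0_from_eps, x0)
assert np.allclose(x0_from_v, x0)
assert np.allclose(eps_from_v, eps)
```

<p>The choice of target therefore doesn’t change what the network can express in principle, only the implicit weighting of noise levels in the training loss.</p>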
<p>Finally, it’s worth noting that the definition of “diffusion” in the context of generative modelling has grown to be quite broad, and is now <strong>almost equivalent to “iterative refinement”</strong>. A lot of “diffusion models” for discrete input are not actually based on diffusion processes, but they are of course closely related, so the scope of this label has gradually been extended to include them. It’s not clear where to draw the line: if any model which implements iterative refinement through inversion of a gradual corruption process is a diffusion model, then all autoregressive models are also diffusion models. To me, that seems confusing enough so as to render the term useless.</p>
<h2 id="-closing-thoughts"><a name="closing-thoughts"></a> Closing thoughts</h2>
<figure>
<a href="/images/hawaii.jpg"><img src="/images/hawaii.jpg" /></a>
</figure>
<p>Learning about diffusion models right now must be a pretty confusing experience, but the exploration of all these different perspectives has resulted in a diverse toolbox of methods which can all be combined together, because <strong>ultimately, the underlying model is always the same</strong>. I’ve also found that learning about how the different perspectives relate to each other has considerably deepened my understanding. Some things that are a mystery from one perspective are clear as day in another.</p>
<p>If you are just getting started with diffusion, hopefully this post will help guide you towards the right things to learn next. If you are a seasoned diffuser, I hope I’ve broadened your perspectives and I hope you’ve learnt something new nevertheless. Thanks for reading!</p>
<p style="background-color: #eee; padding: 1.2em; font-weight: bold; margin: 3em 0; text-align: center;">
What's your favourite perspective on diffusion? Are there any useful perspectives that I've missed? Please share your thoughts in the comments below, or reach out on Twitter (<a href="https://twitter.com/sedielem">@sedielem</a>) or Threads (<a href="https://www.threads.net/@sanderdieleman">@sanderdieleman</a>) if you prefer. Email is okay too. <br /><br /> I will also be at ICML 2023 in Honolulu and would be happy to chat in person!</p>
<p><em>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</em></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@misc{dieleman2023perspectives,
author = {Dieleman, Sander},
title = {Perspectives on diffusion},
url = {https://sander.ai/2023/07/20/perspectives.html},
year = {2023}
}
</code></pre></div></div>
<h2 id="-acknowledgements"><a name="acknowledgements"></a> Acknowledgements</h2>
<p>Thanks to my colleagues at Google DeepMind for various discussions, which continue to shape my thoughts on this topic! Thanks to Ayan Das, Ira Korshunova, Peyman Milanfar, and Çağlar Ünlü for suggestions and corrections.</p>
<h2 id="-references"><a name="references"></a> References</h2>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:bengio" role="doc-endnote">
<p>Bengio, Lamblin, Popovici, Larochelle, “<a href="https://proceedings.neurips.cc/paper/2006/hash/5da713a690c067105aeb2fae32403405-Abstract.html">Greedy Layer-Wise Training of Deep Networks</a>”, Neural Information Processing Systems, 2006. <a href="#fnref:bengio" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:bengio:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:noneq" role="doc-endnote">
<p>Sohl-Dickstein, Weiss, Maheswaranathan, Ganguli, “<a href="https://arxiv.org/abs/1503.03585">Deep Unsupervised Learning using Nonequilibrium Thermodynamics</a>”, International Conference on Machine Learning, 2015. <a href="#fnref:noneq" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ddpm" role="doc-endnote">
<p>Ho, Jain, Abbeel, “<a href="https://proceedings.neurips.cc/paper/2020/hash/4c5bcfec8584af0d967f1ab10179ca4b-Abstract.html">Denoising Diffusion Probabilistic Models</a>”, Neural Information Processing Systems, 2020. <a href="#fnref:ddpm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vaekingma" role="doc-endnote">
<p>Kingma and Welling, “<a href="https://arxiv.org/abs/1312.6114">Auto-Encoding Variational Bayes</a>”, International Conference on Learning Representations, 2014. <a href="#fnref:vaekingma" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vaerezende" role="doc-endnote">
<p>Rezende, Mohamed and Wierstra, “<a href="https://arxiv.org/abs/1401.4082">Stochastic Backpropagation and Approximate Inference in Deep Generative Models</a>”, International Conference on Machine Learning, 2014. <a href="#fnref:vaerezende" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:scorematching" role="doc-endnote">
<p>Hyvärinen, “<a href="http://www.jmlr.org/papers/v6/hyvarinen05a.html">Estimation of Non-Normalized Statistical Models by Score Matching</a>”, Journal of Machine Learning Research, 2005. <a href="#fnref:scorematching" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:scorematching:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:ssm" role="doc-endnote">
<p>Song, Garg, Shi, Ermon, “<a href="https://arxiv.org/abs/1905.07088">Sliced Score Matching: A Scalable Approach to Density and Score Estimation</a>”, Uncertainty in Artifical Intelligence, 2019. <a href="#fnref:ssm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:dsm" role="doc-endnote">
<p>Vincent, “<a href="http://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf">A Connection Between Score Matching and Denoising Autoencoders</a>”, Technical report, 2010. <a href="#fnref:dsm" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:dsm:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:songermon" role="doc-endnote">
<p>Song, Ermon, “<a href="https://arxiv.org/abs/1907.05600">Generative Modeling by Estimating Gradients of the Data Distribution</a>”, Neural Information Processing Systems, 2019. <a href="#fnref:songermon" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:anderson" role="doc-endnote">
<p>Anderson, “<a href="https://www.sciencedirect.com/science/article/pii/0304414982900515">Reverse-time diffusion equation models</a>”, Stochastic Processes and their Applications, 1982. <a href="#fnref:anderson" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sde" role="doc-endnote">
<p>Song, Sohl-Dickstein, Kingma, Kumar, Ermon and Poole, “<a href="https://arxiv.org/abs/2011.13456">Score-Based Generative Modeling through Stochastic Differential Equations</a>”, International Conference on Learning Representations, 2021. <a href="#fnref:sde" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:sde:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a> <a href="#fnref:sde:2" class="reversefootnote" role="doc-backlink">↩<sup>3</sup></a></p>
</li>
<li id="fn:nice" role="doc-endnote">
<p>Dinh, Krueger, Bengio, “<a href="https://arxiv.org/abs/1410.8516">NICE: Non-linear Independent Components Estimation</a>”, International Conference on Learning Representations, 2015. <a href="#fnref:nice" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:realnvp" role="doc-endnote">
<p>Dinh, Sohl-Dickstein, Bengio, “<a href="https://arxiv.org/abs/1605.08803">Density estimation using Real NVP</a>”, International Conference on Learning Representations, 2017. <a href="#fnref:realnvp" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:node" role="doc-endnote">
<p>Chen, Rubanova, Bettencourt, Duvenaud, “<a href="https://arxiv.org/abs/1806.07366">Neural Ordinary Differential Equations</a>”, Neural Information Processing Systems, 2018. <a href="#fnref:node" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ffjord" role="doc-endnote">
<p>Grathwohl, Chen, Bettencourt, Sutskever, Duvenaud, “<a href="https://arxiv.org/abs/1810.01367">FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:ffjord" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:elucidating" role="doc-endnote">
<p>Karras, Aittala, Aila, Laine, “<a href="https://arxiv.org/abs/2206.00364">Elucidating the Design Space of Diffusion-Based Generative Models</a>”, Neural Information Processing Systems, 2022. <a href="#fnref:elucidating" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:elucidating:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a> <a href="#fnref:elucidating:2" class="reversefootnote" role="doc-backlink">↩<sup>3</sup></a></p>
</li>
<li id="fn:cdcd" role="doc-endnote">
<p>Dieleman, Sartran, Roshannai, Savinov, Ganin, Richemond, Doucet, Strudel, Dyer, Durkan, Hawthorne, Leblond, Grathwohl, Adler, “<a href="https://arxiv.org/abs/2211.15089">Continuous diffusion for categorical data</a>”, arXiv, 2022. <a href="#fnref:cdcd" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:cdcd:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:flowmatching" role="doc-endnote">
<p>Lipman, Chen, Ben-Hamu, Nickel, Le, “<a href="https://arxiv.org/abs/2210.02747">Flow Matching for Generative Modeling</a>”, International Conference on Learning Representations, 2023. <a href="#fnref:flowmatching" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:flowmatching:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a> <a href="#fnref:flowmatching:2" class="reversefootnote" role="doc-backlink">↩<sup>3</sup></a></p>
</li>
<li id="fn:rectifiedflow" role="doc-endnote">
<p>Liu, Gong, Liu, “<a href="https://arxiv.org/abs/2209.03003">Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow</a>”, International Conference on Learning Representations, 2023. <a href="#fnref:rectifiedflow" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:rectifiedflow:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:stochasticinterpolants" role="doc-endnote">
<p>Albergo, Vanden-Eijnden, “<a href="https://arxiv.org/abs/2209.15571">Building Normalizing Flows with Stochastic Interpolants</a>”, International Conference on Learning Representations, 2023. <a href="#fnref:stochasticinterpolants" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:stochasticinterpolants:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:ddim" role="doc-endnote">
<p>Song, Meng, Ermon, “<a href="https://arxiv.org/abs/2010.02502">Denoising Diffusion Implicit Models</a>”, International Conference on Learning Representations, 2021. <a href="#fnref:ddim" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:lstm" role="doc-endnote">
<p>Hochreiter, Schmidhuber, “<a href="https://ieeexplore.ieee.org/abstract/document/6795963">Long short-term memory</a>”, Neural Computation, 1997. <a href="#fnref:lstm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sdxl" role="doc-endnote">
<p>Podell, English, Lacey, Blattmann, Dockhorn, Muller, Penna, Rombach, “<a href="https://github.com/Stability-AI/generative-models/blob/main/assets/sdxl_report.pdf">SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis</a>”, tech report, 2023. <a href="#fnref:sdxl" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:dbns" role="doc-endnote">
<p>Hinton, Osindero, Teh, “<a href="https://direct.mit.edu/neco/article-abstract/18/7/1527/7065/A-Fast-Learning-Algorithm-for-Deep-Belief-Nets">A Fast Learning Algorithm for Deep Belief Nets</a>”, Neural Computation, 2006. <a href="#fnref:dbns" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:selfcond" role="doc-endnote">
<p>Chen, Zhang, Hinton, “<a href="https://arxiv.org/abs/2208.04202">Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning</a>”, International Conference on Learning Representations, 2023. <a href="#fnref:selfcond" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:selfcond:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:rin" role="doc-endnote">
<p>Jabri, Fleet, Chen, “<a href="https://arxiv.org/abs/2212.11972">Scalable Adaptive Computation for Iterative Generation</a>”, arXiv, 2022. <a href="#fnref:rin" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:imagestats" role="doc-endnote">
<p>Torralba, Oliva, “<a href="https://iopscience.iop.org/article/10.1088/0954-898X/14/3/302/meta">Statistics of Natural Image Categories</a>”, Network: Computation in Neural Systems, 2003. <a href="#fnref:imagestats" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:heat" role="doc-endnote">
<p>Rissanen, Heinonen, Solin, “<a href="https://arxiv.org/abs/2206.13397">Generative Modelling With Inverse Heat Dissipation</a>”, International Conference on Learning Representations, 2023. <a href="#fnref:heat" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:distillation" role="doc-endnote">
<p>Hinton, Vinyals, Dean, “<a href="https://arxiv.org/abs/1503.02531">Distilling the Knowledge in a Neural Network</a>”, Neural Information Processing Systems, Deep Learning and Representation Learning Workshop, 2015. <a href="#fnref:distillation" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:progressive" role="doc-endnote">
<p>Salimans, Ho, “<a href="https://arxiv.org/abs/2202.00512">Progressive Distillation for Fast Sampling of Diffusion Models</a>”, International Conference on Learning Representations, 2022. <a href="#fnref:progressive" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:progressive:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:guided" role="doc-endnote">
<p>Meng, Rombach, Gao, Kingma, Ermon, Ho, Salimans, “<a href="https://arxiv.org/abs/2210.03142">On Distillation of Guided Diffusion Models</a>”, Computer Vision and Pattern Recognition, 2023. <a href="#fnref:guided" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:tract" role="doc-endnote">
<p>Berthelot, Autef, Lin, Yap, Zhai, Hu, Zheng, Talbott, Gu, “<a href="https://arxiv.org/abs/2303.04248">TRACT: Denoising Diffusion Models with Transitive Closure Time-Distillation</a>”, arXiv, 2023. <a href="#fnref:tract" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:d3pm" role="doc-endnote">
<p>Austin, Johnson, Ho, Tarlow, van den Berg, “<a href="https://arxiv.org/abs/2107.03006">Structured Denoising Diffusion Models in Discrete State-Spaces</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:d3pm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:maskgit" role="doc-endnote">
<p>Chang, Zhang, Jiang, Liu, Freeman, “<a href="https://arxiv.org/abs/2202.04200">MaskGIT: Masked Generative Image Transformer</a>”, Computer Vision and Pattern Recognition, 2022. <a href="#fnref:maskgit" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:maskpredict" role="doc-endnote">
<p>Ghazvininejad, Levy, Liu, Zettlemoyer, “<a href="https://arxiv.org/abs/1904.09324">Mask-Predict: Parallel Decoding of Conditional Masked Language Models</a>”, Empirical Methods in Natural Language Processing, 2019. <a href="#fnref:maskpredict" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ardm" role="doc-endnote">
<p>Hoogeboom, Gritsenko, Bastings, Poole, van den Berg, Salimans, “<a href="https://arxiv.org/abs/2110.02037">Autoregressive Diffusion Models</a>”, International Conference on Learning Representations, 2022. <a href="#fnref:ardm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:multinomial" role="doc-endnote">
<p>Hoogeboom, Nielsen, Jaini, Forré, Welling, “<a href="https://arxiv.org/abs/2102.05379">Argmax Flows and Multinomial Diffusion: Learning Categorical Distributions</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:multinomial" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sundae" role="doc-endnote">
<p>Savinov, Chung, Binkowski, Elsen, van den Oord, “<a href="https://arxiv.org/abs/2112.06749">Step-unrolled Denoising Autoencoders for Text Generation</a>”, International Conference on Learning Representations, 2022. <a href="#fnref:sundae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ctmc" role="doc-endnote">
<p>Campbell, Benton, De Bortoli, Rainforth, Deligiannidis, Doucet, “<a href="https://arxiv.org/abs/2205.14987">A continuous time framework for discrete denoising models</a>”, Neural Information Processing Systems, 2022. <a href="#fnref:ctmc" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:discretescore" role="doc-endnote">
<p>Sun, Yu, Dai, Schuurmans, Dai, “<a href="https://arxiv.org/abs/2211.16750">Score-based Continuous-time Discrete Diffusion Models</a>”, International Conference on Learning Representations, 2023. <a href="#fnref:discretescore" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:blackout" role="doc-endnote">
<p>Santos, Fox, Lubbers, Lin, “<a href="https://arxiv.org/abs/2305.11089">Blackout Diffusion: Generative Diffusion Models in Discrete-State Spaces</a>”, International Conference on Machine Learning, 2023. <a href="#fnref:blackout" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sed" role="doc-endnote">
<p>Strudel, Tallec, Altché, Du, Ganin, Mensch, Grathwohl, Savinov, Dieleman, Sifre, Leblond, “<a href="https://arxiv.org/abs/2211.04236">Self-conditioned Embedding Diffusion for Text Generation</a>”, arXiv, 2022. <a href="#fnref:sed" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:indi" role="doc-endnote">
<p>Delbracio, Milanfar, “<a href="https://arxiv.org/abs/2303.11435">Inversion by Direct Iteration: An Alternative to Denoising Diffusion for Image Restoration</a>”, Transactions on Machine Learning Research, 2023. <a href="#fnref:indi" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:deblend" role="doc-endnote">
<p>Heitz, Belcour, Chambon, “<a href="https://arxiv.org/abs/2305.03486">Iterative alpha-(de)Blending: a Minimalist Deterministic Diffusion Model</a>”, SIGGRAPH, 2023. <a href="#fnref:deblend" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:cm" role="doc-endnote">
<p>Song, Dhariwal, Chen, Sutskever, “<a href="https://arxiv.org/abs/2303.01469">Consistency Models</a>”, International Conference on Machine Learning, 2023. <a href="#fnref:cm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:cdm" role="doc-endnote">
<p>Daras, Dagan, Dimakis, Daskalakis, “<a href="https://arxiv.org/abs/2302.09057">Consistent Diffusion Models: Mitigating Sampling Drift by Learning to be Consistent</a>”, arXiv, 2023. <a href="#fnref:cdm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:fpdiffusion" role="doc-endnote">
<p>Lai, Takida, Murata, Uesaka, Mitsufuji, Ermon, “<a href="https://arxiv.org/abs/2210.04296">FP-Diffusion: Improving Score-based Diffusion Models by Enforcing the Underlying Score Fokker-Planck Equation</a>”, International Conference on Machine Learning, 2023. <a href="#fnref:fpdiffusion" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:equivalenceconsistency" role="doc-endnote">
<p>Lai, Takida, Uesaka, Murata, Mitsufuji, Ermon, “<a href="https://arxiv.org/abs/2306.00367">On the Equivalence of Consistency-Type Models: Consistency Models, Consistent Diffusion Models, and Fokker-Planck Regularization</a>”, arXiv, 2023. <a href="#fnref:equivalenceconsistency" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Diffusion models appear to come in many shapes and forms. If you pick two random research papers about diffusion and look at how they describe the model class in their respective introductions, chances are they will go about it in very different ways. This can be both frustrating and enlightening: frustrating, because it makes it harder to spot relationships and equivalences across papers and implementations – but also enlightening, because these various perspectives each reveal new connections and are a breeding ground for new ideas. This blog post is an overview of the perspectives on diffusion I’ve found useful.

Diffusion language models
2023-01-09T00:00:00+00:00
https://sander.ai/2023/01/09/diffusion-language

<p>Diffusion models have completely taken over generative modelling of perceptual signals such as images, audio and video. Why is autoregression still the name of the game for language modelling? And can we do anything about that? Some thoughts about what it will take for other forms of iterative refinement to take over language modelling, the last bastion of autoregression.</p>
<h2 id="-the-rise-of-diffusion-models"><a name="diffusion"></a> The rise of diffusion models</h2>
<figure>
<a href="/images/diffuse2.jpg"><img src="/images/diffuse2.jpg" /></a>
</figure>
<p>Roughly three years ago, things were starting to look as if adversarial image generators were about to be supplanted by a powerful combination of autoregression and discrete representation learning. <a href="https://arxiv.org/abs/1809.11096">BigGAN</a><sup id="fnref:biggan" role="doc-noteref"><a href="#fn:biggan" class="footnote" rel="footnote">1</a></sup> and <a href="https://arxiv.org/abs/1912.04958">StyleGAN</a><sup id="fnref:stylegan" role="doc-noteref"><a href="#fn:stylegan" class="footnote" rel="footnote">2</a></sup> had significantly expanded the capabilities of image generators, but the mode-seeking nature of GANs made them favour realism over diversity. This presented some challenges, and people were having trouble reproducing impressive domain-specific results (e.g. generating realistic human faces) on more diverse training datasets.</p>
<p><a href="https://arxiv.org/abs/1906.00446">VQ-VAE 2</a><sup id="fnref:vqvae2" role="doc-noteref"><a href="#fn:vqvae2" class="footnote" rel="footnote">3</a></sup> and especially <a href="https://arxiv.org/abs/2012.09841">VQGAN</a><sup id="fnref:vqgan" role="doc-noteref"><a href="#fn:vqgan" class="footnote" rel="footnote">4</a></sup> extolled the virtue of a two-stage approach to generative modelling: first turn everything into a highly compressed discrete one-dimensional sequence, and then learn to predict this sequence step-by-step using a powerful autoregressive model. This idea had already proven fruitful before, going back to the original <a href="https://arxiv.org/abs/1711.00937">VQ-VAE</a><sup id="fnref:vqvae" role="doc-noteref"><a href="#fn:vqvae" class="footnote" rel="footnote">5</a></sup>, but these two papers really drove the point home that this was our best bet for generative modelling of diverse data at scale.</p>
<p>But then, a challenger appeared: a new generative modelling approach based on <strong>iterative denoising</strong> was starting to show promise. Yang Song and Stefano Ermon proposed score-based models: while their <a href="https://arxiv.org/abs/1907.05600">NeurIPS 2019 paper</a><sup id="fnref:songermon" role="doc-noteref"><a href="#fn:songermon" class="footnote" rel="footnote">6</a></sup> was more of a proof-of-concept, the next year’s follow-up <a href="https://arxiv.org/abs/2006.09011">‘Improved Techniques for Training Score-Based Generative Models’</a><sup id="fnref:songermon2" role="doc-noteref"><a href="#fn:songermon2" class="footnote" rel="footnote">7</a></sup> showed results that convinced some people (including me!) to take this direction of research more seriously. Another NeurIPS 2020 paper by Jonathan Ho, Ajay Jain and Pieter Abbeel, <a href="https://arxiv.org/abs/2006.11239">‘Denoising Diffusion Probabilistic Models’ (DDPMs)</a><sup id="fnref:ddpm" role="doc-noteref"><a href="#fn:ddpm" class="footnote" rel="footnote">8</a></sup> showed similar results, and it didn’t take people too long to realise that DDPMs and score-based models were two sides of the same coin.</p>
<p>The real triumph of diffusion models over other alternatives for image generation came in 2021, with <a href="https://arxiv.org/abs/2105.05233">‘Diffusion Models Beat GANs on Image Synthesis’</a><sup id="fnref:beatgans" role="doc-noteref"><a href="#fn:beatgans" class="footnote" rel="footnote">9</a></sup> by Prafulla Dhariwal and Alex Nichol. At that point, it was pretty clear to everyone in the know that this approach was poised to take over. Powerful diffusion-based text-to-image models such as <a href="https://arxiv.org/abs/2112.10741">GLIDE</a><sup id="fnref:glide" role="doc-noteref"><a href="#fn:glide" class="footnote" rel="footnote">10</a></sup> started to arrive by the end of that year, and proceeded to go mainstream in 2022.</p>
<p><strong>If you are unfamiliar with diffusion models, I recommend reading at least the first section of my previous blog post <a href="https://benanne.github.io/2022/01/31/diffusion.html#diffusion">‘Diffusion models are autoencoders’</a> for context, before reading the rest of this one.</strong></p>
<h2 id="-diffusion-for-images-a-match-made-in-heaven"><a name="match"></a> Diffusion for images: a match made in heaven</h2>
<figure>
<a href="/images/noisy_mountains.jpg"><img src="/images/noisy_mountains.jpg" alt="A noisy image of a mountain range, with the level of noise gradually decreasing from left to right." /></a>
</figure>
<p>Diffusion models and the human visual system have one important thing in common: <strong>they don’t care too much about high frequencies</strong>. At least, not out of the box. I discussed the reasons for this in some detail in <a href="https://benanne.github.io/2022/01/31/diffusion.html#scale">an earlier blog post</a> (section 5 in particular).</p>
<p>In a nutshell, the different levels of noise at which a diffusion model operates allow it to focus on different spatial frequency components of the image at each iterative refinement step. When sampling an image, the model effectively builds it up from low frequencies to high frequencies, first filling in large-scale structure and then adding progressively more fine-grained details.</p>
<p>During training, we sample a noise level for each training example, add noise to it, and then try to predict the noise. The relative weights with which we sample the different noise levels therefore determine the degree to which the model focuses on large-scale and fine-grained structure. The most commonly used formulation, with uniform weighting of the noise levels, yields a very different objective than the likelihood loss which e.g. autoregressive models are trained with.</p>
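<p>To make this training procedure concrete, here is a minimal, hypothetical sketch (the denoiser is a stand-in for a real neural network, and the log-spaced noise schedule is an assumption, not the specific parameterisation used in any particular paper):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def denoiser(x_noisy, sigma):
    # Stand-in for a trained neural network that would predict the noise.
    # It predicts zeros here, just to keep the sketch runnable.
    return np.zeros_like(x_noisy)

def training_loss(x, num_levels=100, level_probs=None):
    # One diffusion training step on a batch x: sample a noise level per
    # example, add Gaussian noise, and regress the noise. `level_probs`
    # is the relative weighting of noise levels discussed in the text
    # (uniform when None).
    sigmas = np.geomspace(1e-2, 10.0, num_levels)  # assumed schedule
    idx = rng.choice(num_levels, size=len(x), p=level_probs)
    sigma = sigmas[idx][:, None]
    eps = rng.standard_normal(x.shape)
    x_noisy = x + sigma * eps            # corrupt the input
    eps_hat = denoiser(x_noisy, sigma)   # try to predict the noise
    return np.mean((eps_hat - eps) ** 2)

x = rng.standard_normal((16, 8))  # toy batch: 16 examples, 8 dimensions
loss = training_loss(x)           # uniform weighting over noise levels
```

<p>Skewing <code>level_probs</code> changes which noise levels, and hence which spatial frequency components, dominate the objective.</p>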
<p>It turns out that there is a particular weighting which corresponds directly to the likelihood loss<sup id="fnref:likelihood" role="doc-noteref"><a href="#fn:likelihood" class="footnote" rel="footnote">11</a></sup>, but this puts significantly more weight on very low noise levels. Since low noise levels correspond to high spatial frequencies, this also indirectly explains why likelihood-based autoregressive models in pixel space never really took off: they end up spending way too much of their capacity on perceptually meaningless detail, and never get around to modelling larger-scale structure.</p>
<p>Relative to the likelihood loss, uniform weighting across noise levels in diffusion models yields an objective that is much more closely aligned with the human visual system. I don’t believe this was actually known when people first started training diffusion models on images – it was just a lucky coincidence! But we understand this pretty well now, and I think it is one of the two main reasons why this modelling approach completely took over in a matter of two years. (The other reason is of course <strong>classifier-free guidance</strong>, which you can read more about in <a href="https://benanne.github.io/2022/05/26/guidance.html">my previous blog post on the topic</a>.)</p>
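<p>Classifier-free guidance itself is simple to express: at each sampling step, the model’s conditional and unconditional predictions are combined, with a guidance scale <em>w</em> amplifying the effect of the conditioning. A minimal sketch on toy vectors:</p>

```python
import numpy as np

def guide(pred_uncond, pred_cond, w):
    # Classifier-free guidance: w = 1 recovers the conditional prediction;
    # w > 1 extrapolates away from the unconditional one.
    return pred_uncond + w * (pred_cond - pred_uncond)

u = np.zeros(4)          # toy unconditional prediction
c = np.ones(4)           # toy conditional prediction
guided = guide(u, c, 2.0)  # lies beyond c, on the line from u through c
```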
<p>The reason I bring all this up here, is that <strong>it doesn’t bode particularly well for applications of diffusion models beyond the perceptual domain</strong>. Our ears have a similar disdain for high frequencies as our eyes (though to a lesser extent, I believe), but in the language domain, what does “high frequency” even mean<sup id="fnref:prism" role="doc-noteref"><a href="#fn:prism" class="footnote" rel="footnote">12</a></sup>? Given the success of likelihood-based language models, could the relatively lower weight of low noise levels actually prove to be a liability in this setting?</p>
<h2 id="-autoregression-for-language-a-tough-baseline-to-beat"><a name="ar"></a> Autoregression for language: a tough baseline to beat</h2>
<figure>
<a href="/images/arguidance.jpg"><img src="/images/arguidance.jpg" /></a>
</figure>
<p>Autoregression at the word or token level is a very natural way to do language modelling, because to some degree, it reflects how language is produced and consumed: as a one-dimensional sequence, one element at a time, in a particular fixed order. However, if we consider the process through which an abstract thought turns into an utterance, the iterative denoising metaphor starts to look more appealing. When writing a paragraph, the core concepts are generally decided on first, and the exact wording and phrasing doesn’t materialise until later. That said, perhaps it doesn’t matter precisely how humans interact with language: just like how planes don’t fly the same way birds do (h/t Yann LeCun), <strong>the best way to build a practically useful language model need not reflect nature</strong> either.</p>
<p>Practically speaking, autoregressive models have an interface that is somewhat limited: they can be <em>prompted</em>, i.e. tasked to complete a sequence for which a prefix is given. While this has actually been shown to be reasonably versatile in itself, the ability of non-autoregressive models to fill in the blanks (i.e. be conditioned on something other than a prefix, also known as inpainting in the image domain) is potentially quite useful, and not something that comes naturally to autoregressive models (though it is of course possible to do infilling with autoregressive models<sup id="fnref:middle" role="doc-noteref"><a href="#fn:middle" class="footnote" rel="footnote">13</a></sup>).</p>
<h3 id="training-efficiency">Training efficiency</h3>
<p>If we compare autoregression and diffusion side-by-side as different forms of iterative refinement, the former has the distinct advantage that training can be parallelised trivially across all refinement steps. During autoregressive model training, we obtain a useful gradient signal from all steps in the sampling process. This is not true for diffusion models, where we have to sample a particular noise level for each training example. It is not practical to train on many different noise levels for each example, because that would require multiple forward and backward passes through the model. For autoregression, we get gradients for all sequence steps with just a single forward-backward pass.</p>
<p>As a result, <strong>diffusion model training</strong> is almost certainly significantly <strong>less statistically efficient</strong> than autoregressive model training, and slower convergence implies higher computational requirements.</p>
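<p>The claim that autoregressive training extracts a learning signal from every position in a single pass can be illustrated with a toy teacher-forcing step (the “model” is a stand-in returning random logits):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, seq_len = 50, 12
seq = rng.integers(vocab, size=seq_len)

def ar_model(tokens):
    # Stand-in for a causally-masked Transformer: one forward pass yields
    # next-token logits at *every* input position simultaneously.
    return rng.standard_normal((len(tokens), vocab))

logits = ar_model(seq[:-1])   # a single forward pass (teacher forcing)
targets = seq[1:]
logp = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
# Every one of the seq_len - 1 positions contributes to the loss, and
# hence to the gradient, in this one forward-backward pass:
loss = -logp[np.arange(len(targets)), targets].mean()
```

<p>A diffusion model, by contrast, would see this batch at only one sampled noise level per example per pass.</p>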
<h3 id="sampling-efficiency">Sampling efficiency</h3>
<p>Sampling algorithms for diffusion models are very flexible: they allow for sample quality and computational cost to be traded off without retraining, simply by changing the number of sampling steps. This isn’t practical with autoregressive models, where the number of sampling steps is tied directly to the length of the sequence that is to be produced. On the face of it, diffusion models are at an advantage here: perhaps we can get high-quality samples with a number of steps that is significantly lower than the sequence length?</p>
<p>For long enough sequences, this is probably true, but it is important to compare apples to apples. Simply comparing the number of sampling steps across different methods relies on the implicit assumption that all sampling steps have the same cost, and this is not the case. Leaving aside the fact that a single diffusion sampling step can sometimes require multiple forward passes through the model, the cost of an individual forward pass also differs. Autoregressive models can benefit substantially from <em>caching</em>, i.e. re-use of activations computed during previous sampling steps, which significantly reduces the cost of each step. This is not the case for diffusion models, because the level of noise present in the input changes throughout sampling, so each sampling step requires a full forward pass across the entire input.</p>
<p>Therefore, the break-even point at which diffusion sampling becomes more efficient than autoregressive sampling is probably at a number of steps <em>significantly below</em> the length of the sequence. Whether this is actually attainable in practice remains to be seen.</p>
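<p>The caching mechanism in question can be sketched for a toy single-head attention layer (the key/value “projections” are hypothetical scalar multiplications standing in for learned linear maps):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16

def project_kv(token_vec):
    # Stand-in for the key/value projections of one attention layer.
    return 0.5 * token_vec, 2.0 * token_vec

# Autoregressive decoding with a KV cache: each step only runs the
# projections for the single *new* token and appends them to the cache;
# activations from earlier steps are reused as-is.
cache_k, cache_v = [], []
out = None
for step in range(8):
    new_tok = rng.standard_normal(d)
    k, v = project_kv(new_tok)   # work proportional to one token
    cache_k.append(k)
    cache_v.append(v)
    q = new_tok
    scores = np.array([q @ ki for ki in cache_k])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    out = sum(w * vi for w, vi in zip(weights, cache_v))

# A diffusion sampler cannot reuse a cache like this: the noise level of
# the whole canvas changes between steps, so every step must recompute
# the activations for all positions.
```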
<h3 id="why-bother-with-diffusion-at-all">Why bother with diffusion at all?</h3>
<p>The efficiency disadvantages with respect to autoregressive models might lead one to wonder if diffusion-based language modelling is even worth exploring to begin with. Aside from infilling capabilities and metaphorical arguments, there are a few other reasons why I believe it’s worth looking into:</p>
<ul>
<li>
<p>Unlike autoregressive models, which require restricted connectivity patterns to ensure causality (usually achieved by masking), <strong>diffusion model architectures are completely unconstrained</strong>. This enables a lot more creative freedom, as well as potentially benefiting from architectural patterns that are common in other application domains, such as using pooling and upsampling layers to capture structure at multiple scales. One recent example of such creativity is Recurrent Interface Networks<sup id="fnref:rins" role="doc-noteref"><a href="#fn:rins" class="footnote" rel="footnote">14</a></sup>, whose Perceiver IO-like<sup id="fnref:perceiverio" role="doc-noteref"><a href="#fn:perceiverio" class="footnote" rel="footnote">15</a></sup> structure enables efficient re-use of computation across sampling steps.</p>
</li>
<li>
<p>The <strong>flexibility of the sampling procedure</strong> extends beyond trading off quality against computational cost: it can also be modified to amplify the influence of conditioning signals (e.g. through classifier-free guidance), or to include additional constraints without retraining. Li et al.<sup id="fnref:diffusionlm" role="doc-noteref"><a href="#fn:diffusionlm" class="footnote" rel="footnote">16</a></sup> extensively explore the latter ability for text generation (e.g. controlling sentiment or imposing a particular syntactic structure).</p>
</li>
<li>
<p>Who knows what other perks we might uncover by properly exploring this space? The first few papers on diffusion models for images struggled to match results obtained with more established approaches at the time (i.e. GANs, autoregressive models). Work on diffusion models in new domains could follow the same trajectory – <strong>if we don’t try, we’ll never know</strong>.</p>
</li>
</ul>
<h2 id="-diffusion-for-discrete-data"><a name="discrete"></a> Diffusion for discrete data</h2>
<figure>
<a href="/images/discrete.jpg"><img src="/images/discrete.jpg" /></a>
</figure>
<p>Diffusion models operate on continuous inputs by default. When using the score-based formalism, continuity is a requirement because the score function \(\nabla_\mathbf{x} \log p(\mathbf{x})\) is only defined when \(\mathbf{x}\) is continuous. Language is usually represented as a sequence of discrete tokens, so the standard formulation is not applicable. Broadly speaking, there are two ways to tackle this apparent incompatibility:</p>
<ul>
<li>formulate a <strong>discrete corruption process</strong> as an alternative to Gaussian diffusion;</li>
<li><strong>map discrete inputs to continuous vectors</strong> and apply Gaussian diffusion in that space.</li>
</ul>
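<p>As a concrete example of the first option, here is a minimal sketch of an absorbing-state (“masking”) corruption process of the kind that appears in D3PM and MaskGIT (the vocabulary size and linear masking schedule are hypothetical choices for illustration):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = 100
MASK = vocab  # an extra absorbing token appended to the vocabulary

def corrupt(tokens, t):
    # Each token is independently replaced by MASK with probability t,
    # where t runs from 0 (clean data) to 1 (fully masked).
    masked = rng.random(len(tokens)) < t
    return np.where(masked, MASK, tokens)

tokens = rng.integers(vocab, size=12)
partially_masked = corrupt(tokens, 0.5)
fully_masked = corrupt(tokens, 1.0)  # every position becomes MASK
```

<p>The model is then trained to invert this corruption, and sampling iteratively unmasks the sequence.</p>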
<p>The former approach has been explored extensively: D3PM<sup id="fnref:d3pm" role="doc-noteref"><a href="#fn:d3pm" class="footnote" rel="footnote">17</a></sup>, MaskGIT<sup id="fnref:maskgit" role="doc-noteref"><a href="#fn:maskgit" class="footnote" rel="footnote">18</a></sup>, Mask-predict<sup id="fnref:maskpredict" role="doc-noteref"><a href="#fn:maskpredict" class="footnote" rel="footnote">19</a></sup>, ARDM<sup id="fnref:ardm" role="doc-noteref"><a href="#fn:ardm" class="footnote" rel="footnote">20</a></sup>, Multinomial diffusion<sup id="fnref:multinomial" role="doc-noteref"><a href="#fn:multinomial" class="footnote" rel="footnote">21</a></sup>, DiffusER<sup id="fnref:diffuser" role="doc-noteref"><a href="#fn:diffuser" class="footnote" rel="footnote">22</a></sup> and SUNDAE<sup id="fnref:sundae" role="doc-noteref"><a href="#fn:sundae" class="footnote" rel="footnote">23</a></sup> are all different flavours of non-autoregressive iterative refinement using a discrete corruption process. Many (but not all) of these works focus on language modelling as the target application. It should be noted that machine translation has been particularly fertile ground for this line of work, because the strong conditioning signal makes non-autoregressive methods attractive even when their ability to capture diversity is relatively limited. Several works on non-autoregressive machine translation predate the rise of diffusion models.</p>
<p>Unfortunately, moving away from the standard continuous formulation of diffusion models tends to mean giving up on some useful features, such as classifier-free guidance and the ability to use various accelerated sampling algorithms developed specifically for this setting. Luckily, we can stick with continuous Gaussian diffusion simply by <strong>embedding</strong> discrete data in Euclidean space. This approach has recently been explored for language modelling. Some methods, like self-conditioned embedding diffusion (SED)<sup id="fnref:sed" role="doc-noteref"><a href="#fn:sed" class="footnote" rel="footnote">24</a></sup>, use a separate representation learning model to obtain continuous embeddings corresponding to discrete tokens; others jointly fit the embeddings and the diffusion model, like Diffusion-LM<sup id="fnref:diffusionlm:1" role="doc-noteref"><a href="#fn:diffusionlm" class="footnote" rel="footnote">16</a></sup>, CDCD<sup id="fnref:cdcd" role="doc-noteref"><a href="#fn:cdcd" class="footnote" rel="footnote">25</a></sup> and Difformer<sup id="fnref:difformer" role="doc-noteref"><a href="#fn:difformer" class="footnote" rel="footnote">26</a></sup>.</p>
<p><a href="https://arxiv.org/abs/2211.15089"><strong>Continuous diffusion for categorical data (CDCD)</strong></a> is my own work in this space: we set out to explore how diffusion models could be adapted for language modelling. One of the goals behind this research project was to develop a method for diffusion language modelling that looks as familiar as possible to language modelling practitioners. Training diffusion models is a rather different experience from training autoregressive Transformers, and we wanted to <strong>minimise the differences to make this as approachable as possible</strong>. The result is a model whose training procedure is remarkably close to that of BERT<sup id="fnref:bert" role="doc-noteref"><a href="#fn:bert" class="footnote" rel="footnote">27</a></sup>: the input token sequence is embedded, noise is added to the embeddings, and the model learns to predict the original tokens using the cross-entropy loss (<em>score interpolation</em>). The model architecture is a standard Transformer. We address the issue of finding the right weighting for the different noise levels with an active learning strategy (<em>time warping</em>), which adapts the distribution of sampled noise levels on the fly during training.</p>
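<p>A greatly simplified, hypothetical sketch of this style of training step (ignoring time warping and the details of score interpolation, with a simple scoring rule against the embedding table standing in for the Transformer):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d, seq_len = 100, 32, 10

E = 0.1 * rng.standard_normal((vocab, d))  # token embeddings (toy init)

def model(noisy_emb, sigma):
    # Stand-in for a standard (non-causal) Transformer; here we simply
    # score each noisy embedding against the embedding table.
    return noisy_emb @ E.T  # per-position logits over the vocabulary

tokens = rng.integers(vocab, size=seq_len)
sigma = 0.5  # sampled noise level
x = E[tokens] + sigma * rng.standard_normal((seq_len, d))  # embed + noise

logits = model(x, sigma)
logp = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
loss = -logp[np.arange(seq_len), tokens].mean()  # predict original tokens
```

<p>Apart from the added noise, this loop is structurally very close to masked language model training, which is the point of the design.</p>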
<p>Another way to do language modelling with Gaussian diffusion, which to my knowledge has not been explored extensively so far, is to <strong>learn higher-level continuous representations</strong> rather than embed individual tokens. This would require a representation learning approach powerful enough to yield representations that can be decoded back into readable text (potentially by a lightweight autoregressive decoder). Autoencoders applied to token sequences tend to produce representations that fail to capture the least predictable components of the input, which carry precisely the most salient information. Perhaps contrastive methods, or methods that try to capture the dynamics of text (such as Time Control<sup id="fnref:timecontrol" role="doc-noteref"><a href="#fn:timecontrol" class="footnote" rel="footnote">28</a></sup>) could be more suitable for this purpose.</p>
<h2 id="-closing-thoughts"><a name="closing-thoughts"></a> Closing thoughts</h2>
<figure>
<a href="/images/sunset2.jpg"><img src="/images/sunset2.jpg" /></a>
</figure>
<p>While CDCD models produce reasonable samples, and are relatively easy to scale due to their similarity to existing language models, the efficiency advantages of autoregression make it a very tough baseline to beat. I believe it is still <strong>too early to consider diffusion as a serious alternative to autoregression for generative language modelling at scale</strong>. As it stands, we also know next to nothing about scaling laws for diffusion models. Perhaps ideas such as latent self-conditioning<sup id="fnref:rins:1" role="doc-noteref"><a href="#fn:rins" class="footnote" rel="footnote">14</a></sup> could make diffusion more competitive, by improving computational efficiency, but it’s not clear that this will be sufficient. Further exploration of this space has the potential to pay off handsomely!</p>
<p>All in all, I have become convinced that the key to powerful generative models is <strong>iterative refinement</strong>: rather than generating a sample in a single pass through a neural network, the model is applied repeatedly to refine a canvas, and hence the unrolled sampling procedure corresponds to a much “deeper” computation graph. Exactly which algorithm one uses to achieve this might not matter too much in the end, whether it be autoregression, diffusion, or something else entirely. I have a lot more thoughts about this, so perhaps this could be the subject of a future blog post.</p>
<p><em>On an unrelated note: I’ve disabled Disqus comments on all of my blog posts, as their ads seem to have gotten very spammy. I don’t have a good alternative to hand right now, so in the meantime, feel free to tweet your thoughts at me instead <a href="https://twitter.com/sedielem">@sedielem</a>, or send me an email. When I eventually revamp this blog at some point in the future, I will look into re-enabling comments. Apologies for the inconvenience!</em></p>
<p><em>UPDATE (April 7): I have reenabled Disqus comments.</em></p>
<p><em>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</em></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@misc{dieleman2023language,
author = {Dieleman, Sander},
title = {Diffusion language models},
url = {https://benanne.github.io/2023/01/09/diffusion-language.html},
year = {2023}
}
</code></pre></div></div>
<h2 id="-acknowledgements"><a name="acknowledgements"></a> Acknowledgements</h2>
<p>Thanks to my collaborators on the CDCD project, and all my colleagues at DeepMind.</p>
<h2 id="-references"><a name="references"></a> References</h2>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:biggan" role="doc-endnote">
<p>Brock, Donahue, Simonyan, “<a href="https://arxiv.org/abs/1809.11096">Large Scale GAN Training for High Fidelity Natural Image Synthesis</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:biggan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:stylegan" role="doc-endnote">
<p>Karras, Laine, Aittala, Hellsten, Lehtinen, Aila, “<a href="https://arxiv.org/abs/1912.04958">Analyzing and Improving the Image Quality of StyleGAN</a>”, Computer Vision and Pattern Recognition, 2020. <a href="#fnref:stylegan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqvae2" role="doc-endnote">
<p>Razavi, van den Oord and Vinyals, “<a href="https://arxiv.org/abs/1906.00446">Generating Diverse High-Fidelity Images with VQ-VAE-2</a>”, Neural Information Processing Systems, 2019. <a href="#fnref:vqvae2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqgan" role="doc-endnote">
<p>Esser, Rombach and Ommer, “<a href="https://arxiv.org/abs/2012.09841">Taming Transformers for High-Resolution Image Synthesis</a>”, Computer Vision and Pattern Recognition, 2021. <a href="#fnref:vqgan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqvae" role="doc-endnote">
<p>van den Oord, Vinyals and Kavukcuoglu, “<a href="https://arxiv.org/abs/1711.00937">Neural Discrete Representation Learning</a>”, Neural Information Processing Systems, 2017. <a href="#fnref:vqvae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:songermon" role="doc-endnote">
<p>Song and Ermon, “<a href="https://arxiv.org/abs/1907.05600">Generative Modeling by Estimating Gradients of the Data Distribution</a>”, Neural Information Processing Systems, 2019. <a href="#fnref:songermon" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:songermon2" role="doc-endnote">
<p>Song and Ermon, “<a href="https://arxiv.org/abs/2006.09011">Improved Techniques for Training Score-Based Generative Models</a>”, Neural Information Processing Systems, 2020. <a href="#fnref:songermon2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ddpm" role="doc-endnote">
<p>Ho, Jain and Abbeel, “<a href="https://arxiv.org/abs/2006.11239">Denoising Diffusion Probabilistic Models</a>”, Neural Information Processing Systems, 2020. <a href="#fnref:ddpm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:beatgans" role="doc-endnote">
<p>Dhariwal, Nichol, “<a href="https://arxiv.org/abs/2105.05233">Diffusion Models Beat GANs on Image Synthesis</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:beatgans" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:glide" role="doc-endnote">
<p>Nichol, Dhariwal, Ramesh, Shyam, Mishkin, McGrew, Sutskever, Chen, “<a href="https://arxiv.org/abs/2112.10741">GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models</a>”, arXiv, 2021. <a href="#fnref:glide" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:likelihood" role="doc-endnote">
<p>Song, Durkan, Murray, Ermon, “<a href="https://arxiv.org/abs/2101.09258">Maximum Likelihood Training of Score-Based Diffusion Models</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:likelihood" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:prism" role="doc-endnote">
<p>Tamkin, Jurafsky, Goodman, “<a href="https://arxiv.org/abs/2011.04823">Language Through a Prism: A Spectral Approach for Multiscale Language Representations</a>”, Neural Information Processing Systems, 2020. <a href="#fnref:prism" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:middle" role="doc-endnote">
<p>Bavarian, Jun, Tezak, Schulman, McLeavey, Tworek, Chen, “<a href="https://arxiv.org/abs/2207.14255">Efficient Training of Language Models to Fill in the Middle</a>”, arXiv, 2022. <a href="#fnref:middle" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:rins" role="doc-endnote">
<p>Jabri, Fleet, Chen, “<a href="https://arxiv.org/abs/2212.11972">Scalable Adaptive Computation for Iterative Generation</a>”, arXiv, 2022. <a href="#fnref:rins" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:rins:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:perceiverio" role="doc-endnote">
<p>Jaegle, Borgeaud, Alayrac, Doersch, Ionescu, Ding, Koppula, Zoran, Brock, Shelhamer, Hénaff, Botvinick, Zisserman, Vinyals, Carreira, “<a href="https://arxiv.org/abs/2107.14795">Perceiver IO: A General Architecture for Structured Inputs & Outputs</a>”, International Conference on Learning Representations, 2022. <a href="#fnref:perceiverio" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:diffusionlm" role="doc-endnote">
<p>Li, Thickstun, Gulrajani, Liang, Hashimoto, “<a href="https://arxiv.org/abs/2205.14217">Diffusion-LM Improves Controllable Text Generation</a>”, Neural Information Processing Systems, 2022. <a href="#fnref:diffusionlm" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:diffusionlm:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:d3pm" role="doc-endnote">
<p>Austin, Johnson, Ho, Tarlow, van den Berg, “<a href="https://arxiv.org/abs/2107.03006">Structured Denoising Diffusion Models in Discrete State-Spaces</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:d3pm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:maskgit" role="doc-endnote">
<p>Chang, Zhang, Jiang, Liu, Freeman, “<a href="https://arxiv.org/abs/2202.04200">MaskGIT: Masked Generative Image Transformer</a>”, Computer Vision and Pattern Recognition, 2022. <a href="#fnref:maskgit" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:maskpredict" role="doc-endnote">
<p>Ghazvininejad, Levy, Liu, Zettlemoyer, “<a href="https://arxiv.org/abs/1904.09324">Mask-Predict: Parallel Decoding of Conditional Masked Language Models</a>”, Empirical Methods in Natural Language Processing, 2019. <a href="#fnref:maskpredict" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ardm" role="doc-endnote">
<p>Hoogeboom, Gritsenko, Bastings, Poole, van den Berg, Salimans, “<a href="https://arxiv.org/abs/2110.02037">Autoregressive Diffusion Models</a>”, International Conference on Learning Representations, 2022. <a href="#fnref:ardm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:multinomial" role="doc-endnote">
<p>Hoogeboom, Nielsen, Jaini, Forré, Welling, “<a href="https://arxiv.org/abs/2102.05379">Argmax Flows and Multinomial Diffusion: Learning Categorical Distributions</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:multinomial" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:diffuser" role="doc-endnote">
<p>Reid, Hellendoorn, Neubig, “<a href="https://arxiv.org/abs/2210.16886">DiffusER: Discrete Diffusion via Edit-based Reconstruction</a>”, arXiv, 2022. <a href="#fnref:diffuser" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sundae" role="doc-endnote">
<p>Savinov, Chung, Binkowski, Elsen, van den Oord, “<a href="https://arxiv.org/abs/2112.06749">Step-unrolled Denoising Autoencoders for Text Generation</a>”, International Conference on Learning Representations, 2022. <a href="#fnref:sundae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sed" role="doc-endnote">
<p>Strudel, Tallec, Altché, Du, Ganin, Mensch, Grathwohl, Savinov, Dieleman, Sifre, Leblond, “<a href="https://arxiv.org/abs/2211.04236">Self-conditioned Embedding Diffusion for Text Generation</a>”, arXiv, 2022. <a href="#fnref:sed" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:cdcd" role="doc-endnote">
<p>Dieleman, Sartran, Roshannai, Savinov, Ganin, Richemond, Doucet, Strudel, Dyer, Durkan, Hawthorne, Leblond, Grathwohl, Adler, “<a href="https://arxiv.org/abs/2211.15089">Continuous diffusion for categorical data</a>”, arXiv, 2022. <a href="#fnref:cdcd" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:difformer" role="doc-endnote">
<p>Gao, Guo, Tan, Zhu, Zhang, Bian, Xu, “<a href="https://arxiv.org/abs/2212.09412">Difformer: Empowering Diffusion Model on Embedding Space for Text Generation</a>”, arXiv, 2022. <a href="#fnref:difformer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:bert" role="doc-endnote">
<p>Devlin, Chang, Lee, Toutanova, “<a href="https://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a>”, North American Chapter of the Association for Computational Linguistics, 2019. <a href="#fnref:bert" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:timecontrol" role="doc-endnote">
<p>Wang, Durmus, Goodman, Hashimoto, “<a href="https://arxiv.org/abs/2203.11370">Language modeling via stochastic processes</a>”, International Conference on Learning Representations, 2022. <a href="#fnref:timecontrol" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Diffusion models have completely taken over generative modelling of perceptual signals such as images, audio and video. Why is autoregression still the name of the game for language modelling? And can we do anything about that? Some thoughts about what it will take for other forms of iterative refinement to take over language modelling, the last bastion of autoregression.Guidance: a cheat code for diffusion models2022-05-26T00:00:00+01:002022-05-26T00:00:00+01:00https://sander.ai/2022/05/26/guidance<p>Classifier-free diffusion guidance<sup id="fnref:cf" role="doc-noteref"><a href="#fn:cf" class="footnote" rel="footnote">1</a></sup> dramatically improves samples produced by conditional diffusion models at almost no cost. It is simple to implement and extremely effective. It is also an essential component of <a href="https://openai.com/dall-e-2/">OpenAI’s DALL·E 2</a><sup id="fnref:dalle2" role="doc-noteref"><a href="#fn:dalle2" class="footnote" rel="footnote">2</a></sup> and <a href="https://imagen.research.google/">Google’s Imagen</a><sup id="fnref:imagen" role="doc-noteref"><a href="#fn:imagen" class="footnote" rel="footnote">3</a></sup>, powering their spectacular image generation results. In this blog post, I share my perspective and try to give some intuition about how it works.</p>
<h2 id="-diffusion-guidance"><a name="guidance"></a> Diffusion guidance</h2>
<figure>
<a href="/images/diffuse2.jpg"><img src="/images/diffuse2.jpg" /></a>
</figure>
<p>Barely two years ago, they were a niche interest on the fringes of generative modelling research, but today, <strong>diffusion models</strong> are the go-to model class for image and audio generation. In <a href="https://benanne.github.io/2022/01/31/diffusion.html">my previous blog post</a>, I discussed the link between diffusion models and autoencoders. <strong>If you are unfamiliar with diffusion models, I recommend reading at least <a href="https://benanne.github.io/2022/01/31/diffusion.html#diffusion">the first section of that post</a> for context, before reading the rest of this one.</strong></p>
<p>Diffusion models are generative models, which means they model a high-dimensional data distribution \(p(x)\). Rather than trying to approximate \(p(x)\) directly (which is what likelihood-based models do), they try to predict the so-called <em>score function</em>, \(\nabla_x \log p(x)\).</p>
<p>To sample from a diffusion model, an input is initialised to random noise, and is then iteratively denoised by taking steps in the direction of the score function (i.e. the direction in which the log-likelihood increases fastest), with some additional noise mixed in to avoid getting stuck in modes of the distribution. This is called <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_Langevin_dynamics">Stochastic Gradient Langevin Dynamics (SGLD)</a>. Admittedly, this is a bit of a caricature of the samplers people actually use in practice nowadays, but it’s not too far off the truth.</p>
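<p>A minimal sketch of this sampling loop, with the score function of a standard Gaussian standing in for a learned model:</p>

```python
import math
import random

random.seed(1)

# Hedged sketch of Langevin-style sampling. The score function of a
# standard Gaussian stands in for a learned diffusion model here.
def score(x):
    # grad_x log N(x; 0, I) = -x
    return [-xi for xi in x]

def langevin_step(x, step_size):
    # Move along the score, and mix in fresh noise to avoid collapsing
    # onto a mode of the distribution.
    return [xi + step_size * si + math.sqrt(2 * step_size) * random.gauss(0, 1)
            for xi, si in zip(x, score(x))]

x = [random.gauss(0, 1) for _ in range(8)]  # initialise the canvas with noise
for _ in range(500):
    x = langevin_step(x, step_size=0.01)
```

<p><small>With a learned, noise-level-dependent score model and a schedule for the step size, this becomes a (simplified) diffusion sampler.</small></p>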
<p>In conditional diffusion models, we have an additional input \(y\) (for example, a class label or a text sequence) and we try to model the conditional distribution \(p(x \mid y)\) instead. In practice, this means learning to predict the conditional score function \(\nabla_x \log p(x \mid y)\).</p>
<p>One neat aspect of the score function is that it is invariant to normalisation of the distribution: if we only know the distribution \(p(x)\) up to a constant, i.e. we have \(p(x) = \frac{\tilde{p}(x)}{Z}\) and we only know \(\tilde{p}(x)\), then we can still compute the score function:</p>
\[\nabla_x \log \tilde{p}(x) = \nabla_x \log \left( p(x) \cdot Z \right) = \nabla_x \left( \log p(x) + \log Z \right) = \nabla_x \log p(x),\]
<p>where we have made use of the linearity of the gradient operator, and the fact that the normalisation constant \(Z = \int \tilde{p}(x) \mathrm{d} x\) does not depend on \(x\) (so its derivative w.r.t. \(x\) is zero).</p>
<p>Unnormalised probability distributions come up all the time, so this is a useful property. For conditional models, it enables us to apply <a href="https://en.wikipedia.org/wiki/Bayes%27_theorem">Bayes’ rule</a> to decompose the score function into an unconditional component, and a component that “mixes in” the conditioning information:</p>
\[p(x \mid y) = \frac{p(y \mid x) \cdot p(x)}{p(y)}\]
\[\implies \log p(x \mid y) = \log p(y \mid x) + \log p(x) - \log p(y)\]
\[\implies \nabla_x \log p(x \mid y) = \nabla_x \log p(y \mid x) + \nabla_x \log p(x) ,\]
<p>where we have used that \(\nabla_x \log p(y) = 0\). In other words, we can obtain the conditional score function as simply the sum of the unconditional score function and a conditioning term. (Note that the conditioning term \(\nabla_x \log p(y \mid x)\) is not itself a score function, because the gradient is w.r.t. \(x\), not \(y\).)</p>
<p><small>Throughout this blog post, I have mostly ignored the <em>time dependency</em> of the distributions estimated by diffusion models. This saves me having to add extra conditioning variables and subscripts everywhere. In practice, diffusion models perform iterative denoising, and are therefore usually conditioned on the level of input noise at each step.</small></p>
<h2 id="-classifier-guidance"><a name="classifier"></a> Classifier guidance</h2>
<figure>
<a href="/images/sorted.jpg"><img src="/images/sorted.jpg" /></a>
</figure>
<p>The first thing to notice is that \(p(y \mid x)\) is exactly what classifiers and other discriminative models try to fit: \(x\) is some high-dimensional input, and \(y\) is a target label. If we have a differentiable discriminative model that estimates \(p(y \mid x)\), then we can also easily obtain \(\nabla_x \log p(y \mid x)\). <strong>All we need to turn an unconditional diffusion model into a conditional one, is a classifier!</strong></p>
<p>The observation that diffusion models can be conditioned <em>post-hoc</em> in this way was mentioned by Sohl-Dickstein et al.<sup id="fnref:equilibrium" role="doc-noteref"><a href="#fn:equilibrium" class="footnote" rel="footnote">4</a></sup> and Song et al.<sup id="fnref:sde" role="doc-noteref"><a href="#fn:sde" class="footnote" rel="footnote">5</a></sup>, but Dhariwal and Nichol<sup id="fnref:beatgans" role="doc-noteref"><a href="#fn:beatgans" class="footnote" rel="footnote">6</a></sup> really drove this point home, and showed how <em>classifier guidance</em> can dramatically improve sample quality by enhancing the conditioning signal, even when used in combination with traditional conditional modelling. To achieve this, they <strong>scale the conditioning term</strong> by a factor:</p>
\[\nabla_x \log p_\gamma(x \mid y) = \nabla_x \log p(x) + \gamma \nabla_x \log p(y \mid x) .\]
<p>\(\gamma\) is called the <strong>guidance scale</strong>, and cranking it up beyond 1 has the effect of <strong>amplifying the influence of the conditioning signal</strong>. It is <em>extremely</em> effective, especially compared to e.g. the truncation trick for GANs<sup id="fnref:biggan" role="doc-noteref"><a href="#fn:biggan" class="footnote" rel="footnote">7</a></sup>, which serves a similar purpose.</p>
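<p>In code, the guided score is just a weighted sum. A sketch, with placeholder vectors standing in for the diffusion model’s score and the classifier gradient:</p>

```python
# Hedged sketch of classifier guidance: the guided score is the unconditional
# score plus the classifier gradient scaled by gamma. The two input vectors
# below are placeholder values; in practice they come from a diffusion model
# and a noise-robust classifier.
def classifier_guided_score(uncond_score, classifier_grad, gamma):
    return [s + gamma * g for s, g in zip(uncond_score, classifier_grad)]

guided = classifier_guided_score([0.5, -1.0], [0.25, 0.5], gamma=10.0)
```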
<figure>
<a href="/images/classifier_guidance.jpg"><img src="/images/classifier_guidance.jpg" alt="Samples from an unconditional diffusion model with classifier guidance, for guidance scales 1.0 (left) and 10.0 (right), taken from Dhariwal & Nichol (2021)." /></a>
<figcaption>Samples from an unconditional diffusion model with classifier guidance, for guidance scales 1.0 (left) and 10.0 (right), taken from Dhariwal & Nichol (2021).</figcaption>
</figure>
<p>If we revert the gradient and the logarithm operations that we used to go from Bayes’ rule to classifier guidance, it’s easier to see what’s going on:</p>
\[p_\gamma(x \mid y) \propto p(x) \cdot p(y \mid x)^\gamma .\]
<p>We are raising the conditional part of the distribution to a power, which corresponds to <strong>tuning the temperature</strong> of that distribution: \(\gamma\) is an inverse temperature parameter. If \(\gamma > 1\), this sharpens the distribution and focuses it onto its modes, by shifting probability mass from the least likely to the most likely values (i.e. the temperature is lowered). Classifier guidance allows us to apply this temperature tuning only to the part of the distribution that captures the influence of the conditioning signal.</p>
<p>In language modelling, it is now commonplace to train a powerful unconditional language model once, and then adapt it to downstream tasks as needed (via few-shot learning or finetuning). Superficially, it would seem that classifier guidance enables the same thing for image generation: one could train a powerful unconditional model, then condition it as needed at test time using a separate classifier.</p>
<p>Unfortunately there are a few snags that make this impractical. Most importantly, because diffusion models operate by gradually denoising inputs, any classifier used for guidance also needs to be able to cope with high noise levels, so that it can provide a useful signal all the way through the sampling process. This usually requires training a bespoke classifier specifically for the purpose of guidance, and at that point, it might be easier to train a traditional conditional generative model end-to-end (or at least finetune an unconditional model to incorporate the conditioning signal).</p>
<p>But even if we have a noise-robust classifier on hand, classifier guidance is inherently limited in its effectiveness: most of the information in the input \(x\) is not relevant to predicting \(y\), and as a result, taking the gradient of the classifier w.r.t. its input can yield arbitrary (and even adversarial) directions in input space.</p>
<h2 id="-classifier-free-guidance"><a name="classifier-free"></a> Classifier-free guidance</h2>
<figure>
<a href="/images/compass.jpg"><img src="/images/compass.jpg" /></a>
</figure>
<p>This is where <strong>classifier-free guidance</strong><sup id="fnref:cf:1" role="doc-noteref"><a href="#fn:cf" class="footnote" rel="footnote">1</a></sup> comes in. As the name implies, it does not require training a separate classifier. Instead, one trains a conditional diffusion model \(p(x \mid y)\), with <em>conditioning dropout</em>: some percentage of the time, the conditioning information \(y\) is removed (10-20% tends to work well). In practice, it is often replaced with a special input value representing the absence of conditioning information. The resulting model is now able to function both as a conditional model \(p(x \mid y)\), and as an unconditional model \(p(x)\), depending on whether the conditioning signal is provided. One might think that this comes at a cost to conditional modelling performance, but the effect seems to be negligible in practice.</p>
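<p>Conditioning dropout itself takes only a few lines. A sketch, where the null value and drop probability are illustrative choices:</p>

```python
import random

random.seed(2)

NULL_LABEL = -1  # special value standing in for "no conditioning" (an assumption)

def maybe_drop_conditioning(y, drop_prob=0.15):
    # During training, occasionally replace the conditioning signal with the
    # null value, so a single model learns both p(x|y) and p(x).
    return NULL_LABEL if random.random() < drop_prob else y

labels = [maybe_drop_conditioning(7) for _ in range(1000)]
drop_fraction = labels.count(NULL_LABEL) / len(labels)
```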
<p>What does this buy us? Recall <strong>Bayes’ rule</strong> from before, but let’s apply it <strong>in the other direction</strong>:</p>
\[p(y \mid x) = \frac{p(x \mid y) \cdot p(y)}{p(x)}\]
\[\implies \log p(y \mid x) = \log p(x \mid y) + \log p(y) - \log p(x)\]
\[\implies \nabla_x \log p(y \mid x) = \nabla_x \log p(x \mid y) - \nabla_x \log p(x) .\]
<p>We have expressed the conditioning term as a function of the conditional and unconditional score functions, both of which our diffusion model provides. We can now substitute this into the formula for classifier guidance:</p>
\[\nabla_x \log p_\gamma(x \mid y) = \nabla_x \log p(x) + \gamma \left( \nabla_x \log p(x \mid y) - \nabla_x \log p(x) \right),\]
<p>or equivalently:</p>
\[\nabla_x \log p_\gamma(x \mid y) = (1 - \gamma) \nabla_x \log p(x) + \gamma \nabla_x \log p(x \mid y) .\]
<p>This is a <a href="https://people.eecs.ku.edu/~jrmiller/Courses/VectorGeometry/AffineTransformations.html">barycentric combination</a> of the conditional and the unconditional score function. For \(\gamma = 0\), we recover the unconditional model, and for \(\gamma = 1\) we get the standard conditional model. But \(\gamma > 1\) is where the magic happens. Below are some examples from OpenAI’s GLIDE model<sup id="fnref:glide" role="doc-noteref"><a href="#fn:glide" class="footnote" rel="footnote">8</a></sup>, obtained using classifier-free guidance.</p>
<figure>
<a href="/images/panda1.jpg"><img src="/images/panda1.jpg" alt="GLIDE sample with guidance scale 1: 'A stained glass window of a panda eating bamboo.'" width="47%" /></a>
<a href="/images/panda3.jpg"><img src="/images/panda3.jpg" alt="GLIDE sample with guidance scale 3: 'A stained glass window of a panda eating bamboo.'" width="47%" /></a>
<figcaption>Two sets of samples from OpenAI's GLIDE model, for the prompt <i>'A stained glass window of a panda eating bamboo.'</i>, taken from <a href="https://arxiv.org/abs/2112.10741">their paper</a>. Guidance scale 1 (no guidance) on the left, guidance scale 3 on the right.</figcaption>
</figure>
<figure>
<a href="/images/corgi1.jpg"><img src="/images/corgi1.jpg" alt="GLIDE sample with guidance scale 1: 'A cozy living room with a painting of a corgi on the wall above a couch and a round coffee table in front of a couch and a vase of flowers on a coffee table.'" width="47%" /></a>
<a href="/images/corgi3.jpg"><img src="/images/corgi3.jpg" alt="GLIDE sample with guidance scale 3: 'A cozy living room with a painting of a corgi on the wall above a couch and a round coffee table in front of a couch and a vase of flowers on a coffee table.'" width="47%" /></a>
<figcaption>Two sets of samples from OpenAI's GLIDE model, for the prompt <i>'A cozy living room with a painting of a corgi on the wall above a couch and a round coffee table in front of a couch and a vase of flowers on a coffee table.'</i>, taken from <a href="https://arxiv.org/abs/2112.10741">their paper</a>. Guidance scale 1 (no guidance) on the left, guidance scale 3 on the right.</figcaption>
</figure>
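<p>The guidance update itself is a one-liner. A sketch with placeholder scores, showing the three regimes of \(\gamma\):</p>

```python
# Hedged sketch of the classifier-free guidance update: a barycentric
# combination of the unconditional and conditional scores. Placeholder
# one-dimensional scores below; in practice both come from the same model,
# evaluated with and without the conditioning input.
def cfg_score(uncond_score, cond_score, gamma):
    return [(1 - gamma) * u + gamma * c
            for u, c in zip(uncond_score, cond_score)]

g0 = cfg_score([1.0], [2.0], gamma=0.0)  # recovers the unconditional score
g1 = cfg_score([1.0], [2.0], gamma=1.0)  # recovers the conditional score
g3 = cfg_score([1.0], [2.0], gamma=3.0)  # extrapolates past the conditional
```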
<p>Why does this work so much better than classifier guidance? The main reason is that we’ve constructed the “classifier” from a generative model. Whereas standard classifiers can take shortcuts and ignore most of the input \(x\) while still obtaining competitive classification results, generative models are afforded no such luxury. This makes the resulting gradient much more robust. As a bonus, we only have to train a single (generative) model, and conditioning dropout is trivial to implement.</p>
<p>It is worth noting that there was only a very brief window of time between the publication of the classifier-free guidance idea, and OpenAI’s GLIDE model, which used it to great effect – so much so that the idea has sometimes been attributed to the latter! Simple yet powerful ideas tend to see rapid adoption. In terms of power-to-simplicity ratio, classifier-free guidance is up there with dropout<sup id="fnref:dropout" role="doc-noteref"><a href="#fn:dropout" class="footnote" rel="footnote">9</a></sup>, in my opinion: a real game changer!</p>
<p><small>(In fact, the GLIDE paper says that they originally trained a text-conditional model, and applied conditioning dropout only in a finetuning phase. Perhaps there is a good reason to do it this way, but I rather suspect that this is simply because they decided to apply the idea to a model they had already trained before!)</small></p>
<p>Clearly, guidance represents a trade-off: it dramatically improves adherence to the conditioning signal, as well as overall sample quality, but <strong>at great cost to diversity</strong>. In conditional generative modelling, this is usually an acceptable trade-off, however: the conditioning signal often already captures most of the variability that we actually care about, and if we desire diversity, we can also simply modify the conditioning signal we provide.</p>
<h2 id="-guidance-for-autoregressive-models"><a name="autoregressive"></a> Guidance for autoregressive models</h2>
<figure>
<a href="/images/arguidance.jpg"><img src="/images/arguidance.jpg" /></a>
</figure>
<p>Is guidance unique to diffusion models? On the face of it, not really. People have pointed out that you can do similar things with other model classes:</p>
<blockquote class="twitter-tweet"><p lang="en" dir="ltr">You can apply a similar trick to classifier-free guidance to autoregressive transformers to sample from a synthetic "super-conditioned" distribution. I trained a CIFAR-10 class-conditional ImageGPT to try this, and I got the following grids with cond_scale 1 (default) and then 3: <a href="https://t.co/gWL5sOqXck">pic.twitter.com/gWL5sOqXck</a></p>— Rivers Have Wings (@RiversHaveWings) <a href="https://twitter.com/RiversHaveWings/status/1478093658716966912?ref_src=twsrc%5Etfw">January 3, 2022</a></blockquote>
<script async="" src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
<p>You can train autoregressive models with conditioning dropout just as easily, and then use two sets of logits, produced with and without conditioning, to construct classifier-free guided logits, exactly as we did before with score functions. Whether we apply this operation to log-probabilities or to gradients of log-probabilities makes no difference, because the gradient operator is linear.</p>
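<p>A sketch of this logit-space version of guidance, with placeholder logits:</p>

```python
import math

# Hedged sketch of classifier-free guidance applied to autoregressive logits.
# The two logit vectors are placeholders; in practice both come from the same
# Transformer, run with and without the conditioning input.
def guided_logits(uncond_logits, cond_logits, gamma):
    return [u + gamma * (c - u) for u, c in zip(uncond_logits, cond_logits)]

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

# Guidance sharpens the conditional model's preference for the second token.
probs = softmax(guided_logits([0.0, 0.0], [0.0, 1.0], gamma=3.0))
```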
<p><strong>There is an important difference however</strong>: whereas the score function in a diffusion model represents the joint distribution across all components of \(x\), \(p(x \mid y)\), the logits produced by autoregressive models represent \(p(x_t \mid x_{<t}, y)\), the <strong>sequential conditional distributions</strong>. You can obtain a joint distribution \(p(x \mid y)\) from this by multiplying all the conditionals together:</p>
\[p(x \mid y) = \prod_{t=1}^T p(x_t \mid x_{<t}, y),\]
<p>but guidance on each of the factors of this product is <strong>not equivalent to applying it to the joint distribution</strong>, as one does in diffusion models:</p>
\[p_\gamma(x \mid y) \neq \prod_{t=1}^T p_\gamma(x_t \mid x_{<t}, y).\]
<p>To see this, let’s first expand the left hand side:</p>
\[p_\gamma(x \mid y) = \frac{p(x) \cdot p(y \mid x)^\gamma}{\int p(x) \cdot p(y \mid x)^\gamma \mathrm{d} x},\]
<p>from which we can divide out the unconditional distribution \(p(x)\) to obtain an input-dependent scale factor that adapts the probabilities based on the conditioning signal \(y\):</p>
\[s_\gamma(x, y) := \frac{p(y \mid x)^\gamma}{\mathbb{E}_{p(x)}\left[ p(y \mid x)^\gamma \right]} .\]
<p>Now we can do the same thing with the right hand side:</p>
\[\prod_{t=1}^T p_\gamma(x_t \mid x_{<t}, y) = \prod_{t=1}^T \frac{p(x_t \mid x_{<t}) \cdot p(y \mid x_{\le t})^\gamma}{\int p(x_t \mid x_{<t}) \cdot p(y \mid x_{\le t})^\gamma \mathrm{d} x_t}\]
<p>We can again factor out \(p(x)\) here:</p>
\[\prod_{t=1}^T p_\gamma(x_t \mid x_{<t}, y) = p(x) \cdot \prod_{t=1}^T \frac{p(y \mid x_{\le t})^\gamma}{\int p(x_t \mid x_{<t}) \cdot p(y \mid x_{\le t})^\gamma \mathrm{d} x_t}.\]
<p>The input-dependent scale factor is now:</p>
\[s_\gamma'(x, y) := \prod_{t=1}^T \frac{p(y \mid x_{\le t})^\gamma}{ \mathbb{E}_{p(x_t \mid x_{<t})} \left[ p(y \mid x_{\le t})^\gamma \right] },\]
<p>which is clearly not equivalent to \(s_\gamma(x, y)\). In other words, guidance on the sequential conditionals redistributes the probability mass in a different way than guidance on the joint distribution does.</p>
<p>I don’t think this has been extensively tested at this point, but my hunch is that diffusion guidance works so well precisely because we are able to apply it to the joint distribution, rather than to individual sequential conditional distributions. As of today, <strong>diffusion models are the only model class for which this approach is tractable</strong> (if there are others, I’d be very curious to learn about them, so please share in the comments!).</p>
<p><small>As an aside: if you have an autoregressive model where the underlying data can be treated as continuous (e.g. an autoregressive model of images like PixelCNN<sup id="fnref:pixelcnn" role="doc-noteref"><a href="#fn:pixelcnn" class="footnote" rel="footnote">10</a></sup> or an Image Transformer<sup id="fnref:imagetransformer" role="doc-noteref"><a href="#fn:imagetransformer" class="footnote" rel="footnote">11</a></sup>), you can actually get gradients w.r.t. the input. This means you can get an efficient estimate of the score function \(\nabla_x \log p(x|y)\) and sample from the model using Langevin dynamics, so you could in theory apply classifier or classifier-free guidance to the joint distribution, in a way that’s equivalent to diffusion guidance!</small></p>
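<p><small>In its simplest form, that sampling procedure would be unadjusted Langevin dynamics. The sketch below is a toy version where the score comes from a 1D Gaussian in closed form, standing in for the input gradient of an autoregressive model’s log-likelihood; the target distribution, step size and number of steps are all arbitrary choices for illustration.</small></p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the score: for a Gaussian N(mu, sigma^2),
# grad_x log p(x) = (mu - x) / sigma^2. In the setting described
# above, this gradient would instead come from backpropagating the
# model's log-likelihood to the input.
mu, sigma = 2.0, 0.5
def score(x):
    return (mu - x) / sigma**2

# Unadjusted Langevin dynamics:
# x <- x + (eta / 2) * score(x) + sqrt(eta) * noise.
eta = 1e-3
x = rng.standard_normal(10_000)  # start from noise
for _ in range(5_000):
    x = x + 0.5 * eta * score(x) + np.sqrt(eta) * rng.standard_normal(x.shape)

print(x.mean(), x.std())  # approaches mu and sigma
```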
<hr />
<p><strong>Update / correction (May 29th)</strong></p>
<p><a href="https://twitter.com/RiversHaveWings/status/1530563830094262273">@RiversHaveWings on Twitter</a> pointed out that the distributions which we modify to apply guidance are \(p_t(x \mid y)\) (where \(t\) is the current timestep in the diffusion process), not \(p(x \mid y)\) (which is equivalent to \(p_0(x \mid y)\)). This is clearly a shortcoming of the notational shortcut I took throughout this blog post (i.e. making the time dependency implicit).</p>
<p>This calls into question my claim above that diffusion model guidance operates on the true joint distribution of the data – though it doesn’t change the fact that guidance does a different thing for autoregressive models and for diffusion models. As ever in deep learning, whether the difference is meaningful in practice will probably have to be established empirically, so it will be interesting to see if classifier-free guidance catches on for other model classes as well!</p>
<hr />
<h2 id="-temperature-tuning-for-diffusion-models"><a name="temperature"></a> Temperature tuning for diffusion models</h2>
<figure>
<a href="/images/temperature.jpg"><img src="/images/temperature.jpg" /></a>
</figure>
<p>One thing people often do with autoregressive models is tune the temperature of the sequential conditional distributions. More intricate procedures to “shape” these distributions are also popular: top-k sampling, nucleus sampling<sup id="fnref:nucleus" role="doc-noteref"><a href="#fn:nucleus" class="footnote" rel="footnote">12</a></sup> and typical sampling<sup id="fnref:typical" role="doc-noteref"><a href="#fn:typical" class="footnote" rel="footnote">13</a></sup> are the main contenders. They are harder to generalise to high-dimensional distributions, so I won’t consider them here.</p>
<p><strong>Can we tune the temperature of a diffusion model?</strong> Sure: instead of factorising \(p(x \mid y)\) and only modifying the conditional component, we can raise the whole thing to the \(\gamma\)’th power, simply by multiplying the score function by \(\gamma\). Unfortunately, this invariably yields terrible results. While tuning the temperatures of the sequential conditionals in autoregressive models works quite well and often improves results, tuning the temperature of the joint distribution seems to be pretty much useless (let me know in the comments if your experience differs!).</p>
<p>Just as with guidance, this is because changing the temperature of the sequential conditionals is <strong>not the same</strong> as changing the temperature of the joint distribution. Working this out is left as an exercise for the reader :)</p>
<p>Note that they do become equivalent when all \(x_t\) are independent (i.e. \(p(x_t \mid x_{<t}) = p(x_t)\)), but if that is the case, using an autoregressive model kind of defeats the point!</p>
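<p>A quick numerical check (again with a made-up joint over two binary variables) confirms both claims: the two tempering procedures disagree on a correlated joint, and coincide on an independent one.</p>

```python
import numpy as np

def temper(p, gamma):
    # Raise a (joint) distribution to the gamma'th power and renormalise.
    q = p**gamma
    return q / q.sum()

def temper_conditionals(p, gamma):
    # Temper p(x1) and each p(x2 | x1) separately, then recombine.
    p_x1 = temper(p.sum(axis=1), gamma)
    p_x2_given_x1 = p / p.sum(axis=1, keepdims=True)
    p_x2_given_x1 = p_x2_given_x1**gamma
    p_x2_given_x1 /= p_x2_given_x1.sum(axis=1, keepdims=True)
    return p_x1[:, None] * p_x2_given_x1

gamma = 2.0

# Correlated joint over two binary variables: the two procedures differ.
p = np.array([[0.4, 0.2],
              [0.1, 0.3]])
print(np.allclose(temper(p, gamma), temper_conditionals(p, gamma)))  # False

# Independent joint: they coincide (but then autoregression is pointless).
p_ind = np.outer([0.6, 0.4], [0.7, 0.3])
print(np.allclose(temper(p_ind, gamma), temper_conditionals(p_ind, gamma)))  # True
```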
<h2 id="-closing-thoughts"><a name="thoughts"></a> Closing thoughts</h2>
<figure>
<a href="/images/sunset2.jpg"><img src="/images/sunset2.jpg" /></a>
</figure>
<p>Guidance is far from the only reason why diffusion models work so well for images: the standard loss function used to train diffusion models de-emphasises low noise levels, relative to the likelihood loss<sup id="fnref:likelihood" role="doc-noteref"><a href="#fn:likelihood" class="footnote" rel="footnote">14</a></sup>. As I mentioned in <a href="https://benanne.github.io/2022/01/31/diffusion.html#scale">my previous blog post</a>, noise levels and image feature scales are closely tied together, and the result is that diffusion models pay less attention to high-frequency content that isn’t visually salient to humans anyway, enabling them to use their capacity more efficiently.</p>
<p>That said, I think guidance is probably the main driver behind the spectacular results we’ve seen over the course of the past six months. I believe guidance constitutes <strong>a real step change in our ability to generate perceptual signals</strong>, going far beyond the steady progress of the last few years that this domain has seen. It is striking that the state-of-the-art models in this domain are able to do what they do, while still being one to two orders of magnitude smaller than state-of-the-art language models in terms of parameter count.</p>
<p>I also believe we’ve only scratched the surface of what’s possible with diffusion models’ steerable sampling process. <em>Dynamic thresholding</em>, introduced this week in the Imagen paper<sup id="fnref:imagen:1" role="doc-noteref"><a href="#fn:imagen" class="footnote" rel="footnote">3</a></sup>, is another simple guidance-enhancing trick to add to our arsenal, and I think there are many more such tricks to be discovered (as well as more elaborate schemes). Guidance seems like it might also enable a kind of “arithmetic” in the image domain like we’ve seen with word embeddings.</p>
<p><em>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</em></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@misc{dieleman2022guidance,
author = {Dieleman, Sander},
title = {Guidance: a cheat code for diffusion models},
url = {https://benanne.github.io/2022/05/26/guidance.html},
year = {2022}
}
</code></pre></div></div>
<h2 id="-acknowledgements"><a name="acknowledgements"></a> Acknowledgements</h2>
<p>Thanks to my colleagues at DeepMind for various discussions, which continue to shape my thoughts on this topic!</p>
<h2 id="-references"><a name="references"></a> References</h2>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:cf" role="doc-endnote">
<p>Ho, Salimans, “<a href="https://openreview.net/forum?id=qw8AKxfYbI">Classifier-Free Diffusion Guidance</a>”, NeurIPS workshop on DGMs and Applications, 2021. <a href="#fnref:cf" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:cf:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:dalle2" role="doc-endnote">
<p>Ramesh, Dhariwal, Nichol, Chu, Chen, “<a href="https://arxiv.org/abs/2204.06125">Hierarchical Text-Conditional Image Generation with CLIP Latents</a>”, arXiv, 2022. <a href="#fnref:dalle2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:imagen" role="doc-endnote">
<p>Saharia, Chan, Saxena, Li, Whang, Ho, Fleet, Norouzi et al., “<a href="https://arxiv.org/abs/2205.11487">Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding</a>”, arXiv, 2022. <a href="#fnref:imagen" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:imagen:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:equilibrium" role="doc-endnote">
<p>Sohl-Dickstein, Weiss, Maheswaranathan and Ganguli, “<a href="https://arxiv.org/abs/1503.03585">Deep Unsupervised Learning using Nonequilibrium Thermodynamics</a>”, International Conference on Machine Learning, 2015. <a href="#fnref:equilibrium" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sde" role="doc-endnote">
<p>Song, Sohl-Dickstein, Kingma, Kumar, Ermon and Poole, “<a href="https://arxiv.org/abs/2011.13456">Score-Based Generative Modeling through Stochastic Differential Equations</a>”, International Conference on Learning Representations, 2021. <a href="#fnref:sde" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:beatgans" role="doc-endnote">
<p>Dhariwal, Nichol, “<a href="https://arxiv.org/abs/2105.05233">Diffusion Models Beat GANs on Image Synthesis</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:beatgans" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:biggan" role="doc-endnote">
<p>Brock, Donahue, Simonyan, “<a href="https://arxiv.org/abs/1809.11096">Large Scale GAN Training for High Fidelity Natural Image Synthesis</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:biggan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:glide" role="doc-endnote">
<p>Nichol, Dhariwal, Ramesh, Shyam, Mishkin, McGrew, Sutskever, Chen, “<a href="https://arxiv.org/abs/2112.10741">GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models</a>”, arXiv, 2021. <a href="#fnref:glide" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:dropout" role="doc-endnote">
<p>Srivastava, Hinton, Krizhevsky, Sutskever, Salakhutdinov, “<a href="https://jmlr.org/papers/v15/srivastava14a.html">Dropout: A Simple Way to Prevent Neural Networks from Overfitting</a>”, Journal of Machine Learning Research, 2014. <a href="#fnref:dropout" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:pixelcnn" role="doc-endnote">
<p>Van den Oord, Kalchbrenner, Kavukcuoglu, “<a href="https://arxiv.org/abs/1601.06759">Pixel Recurrent Neural Networks</a>”, International Conference on Machine Learning, 2016. <a href="#fnref:pixelcnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:imagetransformer" role="doc-endnote">
<p>Parmar, Vaswani, Uszkoreit, Kaiser, Shazeer, Ku, Tran, “<a href="http://proceedings.mlr.press/v80/parmar18a.html">Image Transformer</a>”, International Conference on Machine Learning, 2018. <a href="#fnref:imagetransformer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:nucleus" role="doc-endnote">
<p>Holtzman, Buys, Du, Forbes, Choi, “<a href="https://arxiv.org/abs/1904.09751">The Curious Case of Neural Text Degeneration</a>”, International Conference on Learning Representations, 2020. <a href="#fnref:nucleus" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:typical" role="doc-endnote">
<p>Meister, Pimentel, Wiher, Cotterell, “<a href="https://arxiv.org/abs/2202.00666">Typical Decoding for Natural Language Generation</a>”, arXiv, 2022. <a href="#fnref:typical" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:likelihood" role="doc-endnote">
<p>Song, Durkan, Murray, Ermon, “<a href="https://arxiv.org/abs/2101.09258">Maximum Likelihood Training of Score-Based Diffusion Models</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:likelihood" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
<hr />
<p><strong>Diffusion models are autoencoders</strong> <em>(2022-01-31, https://sander.ai/2022/01/31/diffusion)</em></p>
<p>Diffusion models took off like a rocket at the end of 2019, after the publication of Song & Ermon’s <a href="https://arxiv.org/abs/1907.05600">seminal paper</a>. In this blog post, I highlight a connection to another type of model: the venerable autoencoder.</p>
<h2 id="-diffusion-models"><a name="diffusion"></a> Diffusion models</h2>
<figure>
<a href="/images/diffuse2.jpg"><img src="/images/diffuse2.jpg" /></a>
</figure>
<p>Diffusion models are fast becoming the go-to model for any task that requires producing perceptual signals, such as images and sound. They provide fidelity similar to that of alternatives based on generative adversarial nets (GANs) or autoregressive models, but with much better mode coverage than the former, and a faster and more flexible sampling procedure than the latter.</p>
<p>In a nutshell, diffusion models are constructed by first describing a procedure for gradually turning data into noise, and then training a neural network that learns to invert this procedure step-by-step. Each of these steps consists of <strong>taking a noisy input and making it slightly less noisy</strong>, by filling in some of the information obscured by the noise. If you start from pure noise and do this enough times, it turns out you can generate data this way!</p>
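<p>As a minimal runnable sketch of this idea, consider a 1D toy problem where the data distribution is Gaussian, so the gradient signal needed for denoising is available in closed form (in a real diffusion model, a trained neural network supplies it instead). Starting from pure noise and denoising level by level then recovers the data distribution; the noise schedule, step sizes and step counts below are arbitrary choices for illustration.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setting where the score is known exactly: if the data is N(m, s^2)
# and we corrupt it with Gaussian noise of std sigma_t, the noisy
# marginal is N(m, s^2 + sigma_t^2), whose score has a closed form.
# A trained network would normally replace this function.
m, s = 3.0, 0.7
def score(x, sigma_t):
    return (m - x) / (s**2 + sigma_t**2)

# Annealed Langevin sampling: initialise with noise at the highest
# level, then run a few Langevin steps per level, from noisy to clean.
sigmas = np.geomspace(10.0, 0.01, 50)  # decreasing noise schedule (arbitrary)
x = sigmas[0] * rng.standard_normal(5_000)
for sigma_t in sigmas:
    eta = 0.1 * sigma_t**2             # step size shrinks with the noise level
    for _ in range(50):
        x = x + 0.5 * eta * score(x, sigma_t) \
              + np.sqrt(eta) * rng.standard_normal(x.shape)

print(x.mean(), x.std())  # approaches the data mean and std, m and s
```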
<p>Diffusion models have been around for a while<sup id="fnref:equilibrium" role="doc-noteref"><a href="#fn:equilibrium" class="footnote" rel="footnote">1</a></sup>, but really took off at the end of 2019<sup id="fnref:songermon" role="doc-noteref"><a href="#fn:songermon" class="footnote" rel="footnote">2</a></sup>. The ideas are young enough that the field hasn’t really settled on one particular convention or paradigm to describe them, which means almost every paper uses a slightly different framing, and often a different notation as well. This can make it quite challenging to see the bigger picture when trawling through the literature, of which there is already a lot! Diffusion models go by many names: <em>denoising diffusion probabilistic models</em> (DDPMs)<sup id="fnref:ddpm" role="doc-noteref"><a href="#fn:ddpm" class="footnote" rel="footnote">3</a></sup>, <em>score-based generative models</em>, or <em>generative diffusion processes</em>, among others. Some people just call them <em>energy-based models</em> (EBMs), of which they technically are a special case.</p>
<p>My personal favourite perspective starts from the idea of <em>score matching</em><sup id="fnref:scorematching" role="doc-noteref"><a href="#fn:scorematching" class="footnote" rel="footnote">4</a></sup> and uses a formalism based on stochastic differential equations (SDEs)<sup id="fnref:sde" role="doc-noteref"><a href="#fn:sde" class="footnote" rel="footnote">5</a></sup>. For an in-depth treatment of diffusion models from this perspective, I strongly recommend <a href="https://yang-song.github.io/blog/2021/score/">Yang Song’s richly illustrated blog post</a> (which also comes with code and colabs). It is especially enlightening with regards to the connection between all these different perspectives. If you are familiar with variational autoencoders, you may find <a href="https://lilianweng.github.io/lil-log/2021/07/11/diffusion-models.html">Lilian Weng</a> or <a href="https://jmtomczak.github.io/blog/10/10_ddgms_lvm_p2.html">Jakub Tomczak</a>’s takes on this model family more approachable.</p>
<p>If you are curious about generative modelling in general, <a href="https://benanne.github.io/2020/03/24/audio-generation.html#generative-models">section 3 of my blog post</a> on generating music in the waveform domain contains a brief overview of some of the most important concepts and model flavours.</p>
<h2 id="-denoising-autoencoders"><a name="autoencoders"></a> Denoising autoencoders</h2>
<figure>
<a href="/images/bottleneck.jpg"><img src="/images/bottleneck.jpg" /></a>
</figure>
<p>Autoencoders are neural networks that are trained to predict their input. In and of itself, this is a trivial and meaningless task, but it becomes much more interesting when the network architecture is restricted in some way, or when the input is corrupted and the network has to learn to undo this corruption.</p>
<p>A typical architectural restriction is to introduce some sort of <strong>bottleneck</strong>, which limits the amount of information that can pass through. This implies that the network must learn to encode the most important information efficiently to be able to pass it through the bottleneck, in order to be able to accurately reconstruct the input. Such a bottleneck can be created by reducing the capacity of a particular layer of the network, by introducing quantisation (as in VQ-VAEs<sup id="fnref:vqvae" role="doc-noteref"><a href="#fn:vqvae" class="footnote" rel="footnote">6</a></sup>) or by applying some form of regularisation to it during training (as in VAEs<sup id="fnref:vaekingma" role="doc-noteref"><a href="#fn:vaekingma" class="footnote" rel="footnote">7</a></sup> <sup id="fnref:vaerezende" role="doc-noteref"><a href="#fn:vaerezende" class="footnote" rel="footnote">8</a></sup> or contractive autoencoders<sup id="fnref:cae" role="doc-noteref"><a href="#fn:cae" class="footnote" rel="footnote">9</a></sup>). The internal representation used in this bottleneck (often referred to as the <em>latent representation</em>) is what we are really after. <strong>It should capture the essence of the input, while discarding a lot of irrelevant detail.</strong></p>
<p>Corrupting the input is another viable strategy to make autoencoders learn useful representations. One could argue that models with corrupted input are not autoencoders in the strictest sense, because the input and target output differ, but this is really a semantic discussion – one could just as well consider the corruption procedure part of the model itself. In practice, such models are typically referred to as <em>denoising autoencoders</em>.</p>
<p>Denoising autoencoders were actually some of the first true “deep learning” models: back when we hadn’t yet figured out how to reliably train neural networks deeper than a few layers with simple gradient descent, the prevalent approach was to pre-train networks layer by layer, and denoising autoencoders were frequently used for this purpose<sup id="fnref:sdae" role="doc-noteref"><a href="#fn:sdae" class="footnote" rel="footnote">10</a></sup> (especially by Yoshua Bengio and colleagues at MILA – restricted Boltzmann machines were another option, favoured by Geoffrey Hinton and colleagues).</p>
<h2 id="-one-and-the-same"><a name="peas"></a> One and the same?</h2>
<figure>
<a href="/images/spiderman.jpg"><img src="/images/spiderman.jpg" /></a>
</figure>
<p><strong>So what is the link between modern diffusion models and these – by deep learning standards – ancient autoencoders?</strong> I was inspired to ponder this connection a bit more after seeing some recent tweets speculating about autoencoders making a comeback:</p>
<blockquote class="twitter-tweet"><p lang="en" dir="ltr">Are autoencoders making / going to make a comeback?</p>— David Krueger (@DavidSKrueger) <a href="https://twitter.com/DavidSKrueger/status/1428403382293876743?ref_src=twsrc%5Etfw">August 19, 2021</a></blockquote>
<script async="" src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
<blockquote class="twitter-tweet"><p lang="en" dir="ltr">Can you bring autoencoders back by the time my book is out, I'm aiming for 2023</p>— Peli Grietzer (@peligrietzer) <a href="https://twitter.com/peligrietzer/status/1487186529999069186?ref_src=twsrc%5Etfw">January 28, 2022</a></blockquote>
<script async="" src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
<p>As far as I’m concerned, <strong>the autoencoder comeback is already in full swing, it’s just that we call them diffusion models now!</strong> Let’s unpack this.</p>
<p>The neural network that makes diffusion models tick is trained to estimate the so-called <em>score function</em>, \(\nabla_\mathbf{x} \log p(\mathbf{x})\), the gradient of the log-likelihood w.r.t. the input (a vector-valued function): \(\mathbf{s}_\theta (\mathbf{x}) = \nabla_\mathbf{x} \log p_\theta(\mathbf{x})\). Note that this is different from \(\nabla_\theta \log p_\theta(\mathbf{x})\), the gradient w.r.t. the model parameters \(\theta\), which is the one you would use for training if this were a likelihood-based model. The latter tells you how to change the model parameters to increase the likelihood of the input under the model, whereas the former tells you how to <em>change the input itself</em> to increase its likelihood. (This is the same gradient you would use for DeepDream-style manipulation of images.)</p>
<p>In practice, we want to use the same network at every point in the gradual denoising process, i.e. at every noise level (from pure noise all the way to clean data). To account for this, it takes an additional input \(t \in [0, 1]\) which indicates how far along we are in the denoising process: \(\mathbf{s}_\theta (\mathbf{x}_t, t) = \nabla_{\mathbf{x}_t} \log p_\theta(\mathbf{x}_t)\). By convention, \(t = 0\) corresponds to clean data and \(t = 1\) corresponds to pure noise, so we actually “go back in time” when denoising.</p>
<p>The way you train this network is by taking inputs \(\mathbf{x}\) and corrupting them with additive noise \(\mathbf{\varepsilon}_t \sim \mathcal{N}(0, \sigma_t^2)\), and then predicting \(\mathbf{\varepsilon}_t\) from \(\mathbf{x}_t = \mathbf{x} + \mathbf{\varepsilon}_t\). The reason why this works is not entirely obvious. I recommend reading Pascal Vincent’s 2010 tech report on the subject<sup id="fnref:vincent" role="doc-noteref"><a href="#fn:vincent" class="footnote" rel="footnote">11</a></sup> for an in-depth explanation of why you can do this.</p>
<p>Note that the variance depends on \(t\), because it corresponds to the specific noise level at time \(t\). The loss function is typically just mean squared error, sometimes weighted by a scale factor \(\lambda(t)\), so that some noise levels are prioritised over others:</p>
\[\arg\min_\theta \mathcal{L}_\theta = \arg\min_\theta \mathbb{E}_{t,p(\mathbf{x}_t)} \left[\lambda(t) ||\mathbf{s}_\theta (\mathbf{x} + \mathbf{\varepsilon}_t, t) - \mathbf{\varepsilon}_t||_2^2\right] .\]
<p>Going forward, let’s assume \(\lambda(t) \equiv 1\), which is usually what is done in practice anyway (though other choices have their uses as well<sup id="fnref:maxlikelihood" role="doc-noteref"><a href="#fn:maxlikelihood" class="footnote" rel="footnote">12</a></sup>).</p>
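<p>To make the training setup concrete, here is one evaluation of this loss on a 1D toy problem. The “network” below is not learnt: it is the closed-form optimal predictor of \(\mathbf{\varepsilon}_t\) for Gaussian data, an assumption made purely so that the measured loss has a known minimum to compare against.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# One evaluation of the denoising loss above, with lambda(t) = 1.
# For 1D Gaussian data N(m, s^2), the optimal predictor of eps_t given
# x_t = x + eps_t is E[eps_t | x_t] = sigma_t^2 / (s^2 + sigma_t^2) * (x_t - m),
# with minimal MSE sigma_t^2 * s^2 / (s^2 + sigma_t^2).
m, s = 0.0, 1.0
def model(x_t, sigma_t):
    return sigma_t**2 / (s**2 + sigma_t**2) * (x_t - m)

x = m + s * rng.standard_normal(100_000)       # clean training examples
sigma_t = 0.8                                  # noise level at time t
eps_t = sigma_t * rng.standard_normal(x.shape)
x_t = x + eps_t                                # corrupted inputs
loss = np.mean((model(x_t, sigma_t) - eps_t)**2)

print(loss)  # close to the minimum, 0.8^2 / (1 + 0.8^2)
```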
<p>One key observation is that <strong>predicting \(\mathbf{\varepsilon}_t\) or \(\mathbf{x}\) are equivalent</strong>, so instead, we could just use</p>
\[\arg\min_\theta \mathbb{E}_{t,p(\mathbf{x}_t)} \left[||\mathbf{s}_\theta' (\mathbf{x} + \mathbf{\varepsilon}_t, t) - \mathbf{x}||_2^2\right] .\]
<p>To see that they are equivalent, consider taking a trained model \(\mathbf{s}_\theta\) that predicts \(\mathbf{\varepsilon}_t\) and adding <strong>a new residual connection</strong> to it, going all the way from the input to the output, with a scale factor of \(-1\). This modified model then predicts:</p>
\[\mathbf{\varepsilon}_t - \mathbf{x}_t = \mathbf{\varepsilon}_t - (\mathbf{x} + \mathbf{\varepsilon}_t) = - \mathbf{x} .\]
<p>In other words, we obtain a denoising autoencoder (up to a minus sign). This might seem surprising, but intuitively, it actually makes sense that <strong>to increase the likelihood of a noisy input, you should probably just try to remove the noise, because noise is inherently unpredictable</strong>. Indeed, it turns out that these two things are equivalent.</p>
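<p>This equivalence is easy to check numerically:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Numerical check of the residual-connection argument: if a model's
# noise prediction for x_t = x + eps_t is exact, subtracting x_t from
# it yields -x, the negated clean input.
x = rng.standard_normal(1_000)              # clean data
eps_t = 0.5 * rng.standard_normal(x.shape)  # additive noise
x_t = x + eps_t

eps_prediction = eps_t                      # pretend the model is exact
residual_output = eps_prediction - x_t      # the new residual connection

print(np.allclose(residual_output, -x))  # True
```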
<h2 id="-a-tenuous-connection"><a name="tenuous"></a> A tenuous connection?</h2>
<figure>
<a href="/images/bridge.jpg"><img src="/images/bridge.jpg" /></a>
</figure>
<p>Of course, the title of this blog post is intentionally a bit facetious: while there is a deeper connection between diffusion models and autoencoders than many people realise, the models have completely different purposes and so are not interchangeable.</p>
<p><strong>There are two key differences</strong> with the denoising autoencoders of yore:</p>
<ul>
<li>the additional input \(t\) makes one single model able to handle <strong>many different noise levels</strong> with a single set of shared parameters;</li>
<li>we care about the output of the model, not the internal latent representation, so there is <strong>no need for a bottleneck</strong>. In fact, it would probably do more harm than good.</li>
</ul>
<p>In the strictest sense, both of these differences have no bearing on whether the model can be considered an autoencoder or not. In practice, however, the point of an autoencoder is usually understood to be to learn a useful latent representation, so saying that diffusion models are autoencoders could perhaps be considered a bit… pedantic. Nevertheless, I wanted to highlight this connection because I think many more people know the ins and outs of autoencoders than diffusion models at this point. I believe that appreciating the link between the two can make the latter less daunting to understand.</p>
<p>This link is not merely a curiosity, by the way; it has also been the subject of several papers, which constitute an <strong>early exploration of the ideas that power modern diffusion models</strong>. Apart from the work by Pascal Vincent mentioned earlier<sup id="fnref:vincent:1" role="doc-noteref"><a href="#fn:vincent" class="footnote" rel="footnote">11</a></sup>, there is also a series of papers by Guillaume Alain and colleagues<sup id="fnref:gyom1" role="doc-noteref"><a href="#fn:gyom1" class="footnote" rel="footnote">13</a></sup> that<sup id="fnref:gyom2" role="doc-noteref"><a href="#fn:gyom2" class="footnote" rel="footnote">14</a></sup> are<sup id="fnref:gyom3" role="doc-noteref"><a href="#fn:gyom3" class="footnote" rel="footnote">15</a></sup> worth<sup id="fnref:gyom4" role="doc-noteref"><a href="#fn:gyom4" class="footnote" rel="footnote">16</a></sup> checking<sup id="fnref:gyom5" role="doc-noteref"><a href="#fn:gyom5" class="footnote" rel="footnote">17</a></sup> out<sup id="fnref:gyom6" role="doc-noteref"><a href="#fn:gyom6" class="footnote" rel="footnote">18</a></sup>!</p>
<p><em>[Note that there is another way to connect diffusion models to autoencoders, by viewing them as (potentially infinitely) deep latent variable models. I am personally less interested in that connection because it doesn’t provide me with much additional insight, but it is just as valid. <a href="https://angusturner.github.io/generative_models/2021/06/29/diffusion-probabilistic-models-I.html">Here’s a blog post by Angus Turner</a> that explores this interpretation in detail.]</em></p>
<h2 id="-noise-and-scale"><a name="scale"></a> Noise and scale</h2>
<figure>
<a href="/images/noisy_mountains.jpg"><img src="/images/noisy_mountains.jpg" alt="A noisy image of a mountain range, with the level of noise gradually decreasing from left to right." /></a>
</figure>
<p>I believe the idea of training a <strong>single model to handle many different noise levels with shared parameters</strong> is ultimately the key ingredient that made diffusion models really take off. Song & Ermon<sup id="fnref:songermon:1" role="doc-noteref"><a href="#fn:songermon" class="footnote" rel="footnote">2</a></sup> called them <em>noise-conditional score networks</em> (NCSNs) and provide a very lucid explanation of why this is important, which I won’t repeat here.</p>
<p>The idea of using different noise levels in a single denoising autoencoder had previously been explored for representation learning, but not for generative modelling. Several works suggest gradually decreasing the level of noise over the course of training to improve the learnt representations<sup id="fnref:geras1" role="doc-noteref"><a href="#fn:geras1" class="footnote" rel="footnote">19</a></sup> <sup id="fnref:chandra" role="doc-noteref"><a href="#fn:chandra" class="footnote" rel="footnote">20</a></sup> <sup id="fnref:zhang" role="doc-noteref"><a href="#fn:zhang" class="footnote" rel="footnote">21</a></sup>. Composite denoising autoencoders<sup id="fnref:geras2" role="doc-noteref"><a href="#fn:geras2" class="footnote" rel="footnote">22</a></sup> have multiple subnetworks that handle different noise levels, which is a step closer to the score networks that we use in diffusion models, though still missing the parameter sharing.</p>
<p>A particularly interesting observation stemming from these works, which is also highly relevant to diffusion models, is that <strong>representations learnt using different noise levels tend to correspond to different scales of features</strong>: the higher the noise level, the larger-scale the features that are captured. I think this connection is worth investigating further: it implies that diffusion models fill in missing parts of the input at progressively smaller scales, as the noise level decreases step by step. This does seem to be the case in practice, and it is potentially a useful feature. Concretely, it means that \(\lambda(t)\) can be designed to prioritise the modelling of particular feature scales! This is great, because excessive attention to detail is actually a major problem with likelihood-based models (I’ve previously discussed this in more detail in <a href="https://benanne.github.io/2020/09/01/typicality.html#right-level">section 6 of my blog post about typicality</a>).</p>
<p>This connection between noise levels and feature scales was initially baffling to me: the noise \(\mathbf{\varepsilon}_t\) that we add to the input during training is isotropic Gaussian, so <strong>we are effectively adding noise to each input element (e.g. pixel) independently</strong>. If that is the case, <strong>how can the level of noise (i.e. the variance) possibly impact the scale of the features that are learnt?</strong> I found it helpful to think of it this way:</p>
<ul>
<li>Let’s say we are working with images. Each pixel in an image that could be part of a particular feature (e.g. a human face) provides <strong>evidence for the presence (or absence) of that feature</strong>.</li>
<li>When looking at an image, <strong>we implicitly aggregate the evidence</strong> provided by all the pixels to determine which features are present (e.g. whether there is a face in the image or not).</li>
<li>Larger-scale features in the image will cover a larger proportion of pixels. Therefore, <strong>if a larger-scale feature is present</strong> in an image, there is <strong>more evidence</strong> pointing towards that feature.</li>
<li>Even if we add noise with a very high variance, that evidence will still be apparent, because <strong>when combining information from all pixels, we average out the noise</strong>.</li>
<li>If more pixels are involved in this process, the tolerable noise level increases, because the maximal variance that still allows the noise to be cancelled out is much higher. For smaller-scale features, however, recovery will be impossible, because the noise dominates when we can only aggregate information from a smaller set of pixels.</li>
</ul>
<p>Concretely, if an image contains a human face and we add a lot of noise to it, we will probably no longer be able to discern the face if it is far away from the camera (i.e. covers fewer pixels in the image), whereas if it is close to the camera, we might still see a faint outline. The header image of this section provides another example: the level of noise decreases from left to right. On the very left, we can still see the rough outline of a mountain despite very high levels of noise.</p>
<p>This is completely handwavy, but it provides some intuition for why there is a direct correspondence between the variance of the noise and the scale of features captured by denoising autoencoders and score networks.</p>
<h2 id="-closing-thoughts"><a name="thoughts"></a> Closing thoughts</h2>
<figure>
<a href="/images/sunset.jpg"><img src="/images/sunset.jpg" /></a>
</figure>
<p>So there you have it: <strong>diffusion models are autoencoders. Sort of. When you squint a bit.</strong> Here are some key takeaways, to wrap up:</p>
<ul>
<li>Learning to predict the score function \(\nabla_\mathbf{x} \log p(\mathbf{x})\) of a distribution can be achieved by learning to denoise examples of that distribution. This is a core underlying idea that powers modern diffusion models.</li>
<li>Compared to denoising autoencoders, score networks in diffusion models can handle all noise levels with a single set of parameters, and do not have bottlenecks. But other than that, they do the same thing.</li>
<li>Noise levels and feature scales are closely linked: high noise levels lead to models capturing large-scale features, low noise levels lead to models focusing on fine-grained features.</li>
</ul>
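<p>The first takeaway can be verified in a toy setting where everything is known in closed form (a sketch of my own, not from the post): for data \(x \sim \mathcal{N}(0, 1)\) corrupted with noise of scale \(\sigma\), a denoiser fit by least squares yields the score of the <em>noisy</em> distribution via what is often called Tweedie's formula, \(\nabla_{\tilde{x}} \log p(\tilde{x}) = (\mathbb{E}[x \mid \tilde{x}] - \tilde{x}) / \sigma^2\).</p>

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.5
x = rng.standard_normal(200000)                   # clean data: p(x) = N(0, 1)
x_noisy = x + sigma * rng.standard_normal(200000)

# "Train" a denoiser by least squares. For this Gaussian toy case the optimal
# denoiser is linear, so a single coefficient suffices: x_hat = a * x_noisy.
a = np.sum(x * x_noisy) / np.sum(x_noisy**2)      # roughly 1 / (1 + sigma^2)

# Tweedie's formula turns the denoiser into the score of the noisy distribution.
score_learned = (a * x_noisy - x_noisy) / sigma**2

# Ground truth: the noisy marginal is N(0, 1 + sigma^2), with score -x / (1 + sigma^2).
score_exact = -x_noisy / (1 + sigma**2)
print(float(np.mean(np.abs(score_learned - score_exact))))
```

<p>Here the denoiser is learnt purely from (clean, noisy) pairs, yet it recovers the score of the noisy distribution almost exactly.</p>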
<p><em>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</em></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@misc{dieleman2022diffusion,
author = {Dieleman, Sander},
title = {Diffusion models are autoencoders},
url = {https://benanne.github.io/2022/01/31/diffusion.html},
year = {2022}
}
</code></pre></div></div>
<h2 id="-acknowledgements"><a name="acknowledgements"></a> Acknowledgements</h2>
<p>Thanks to Conor Durkan and Katie Millican for fruitful discussions!</p>
<h2 id="-references"><a name="references"></a> References</h2>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:equilibrium" role="doc-endnote">
<p>Sohl-Dickstein, Weiss, Maheswaranathan and Ganguli, “<a href="https://arxiv.org/abs/1503.03585">Deep Unsupervised Learning using Nonequilibrium Thermodynamics</a>”, International Conference on Machine Learning, 2015. <a href="#fnref:equilibrium" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:songermon" role="doc-endnote">
<p>Song and Ermon, “<a href="https://arxiv.org/abs/1907.05600">Generative Modeling by Estimating Gradients of the Data Distribution</a>”, Neural Information Processing Systems, 2019. <a href="#fnref:songermon" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:songermon:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:ddpm" role="doc-endnote">
<p>Ho, Jain and Abbeel, “<a href="https://arxiv.org/abs/2006.11239">Denoising Diffusion Probabilistic Models</a>”, Neural Information Processing Systems, 2020. <a href="#fnref:ddpm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:scorematching" role="doc-endnote">
<p>Hyvarinen, “<a href="https://www.jmlr.org/papers/v6/hyvarinen05a.html">Estimation of Non-Normalized Statistical Models by Score Matching</a>”, Journal of Machine Learning Research, 2005. <a href="#fnref:scorematching" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sde" role="doc-endnote">
<p>Song, Sohl-Dickstein, Kingma, Kumar, Ermon and Poole, “<a href="https://arxiv.org/abs/2011.13456">Score-Based Generative Modeling through Stochastic Differential Equations</a>”, International Conference on Learning Representations, 2021. <a href="#fnref:sde" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqvae" role="doc-endnote">
<p>van den Oord, Vinyals and Kavukcuoglu, “<a href="https://arxiv.org/abs/1711.00937">Neural Discrete Representation Learning</a>”, Neural Information Processing Systems, 2017. <a href="#fnref:vqvae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vaekingma" role="doc-endnote">
<p>Kingma and Welling, “<a href="https://arxiv.org/abs/1312.6114">Auto-Encoding Variational Bayes</a>”, International Conference on Learning Representations, 2014. <a href="#fnref:vaekingma" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vaerezende" role="doc-endnote">
<p>Rezende, Mohamed and Wierstra, “<a href="https://arxiv.org/abs/1401.4082">Stochastic Backpropagation and Approximate Inference in Deep Generative Models</a>”, International Conference on Machine Learning, 2014. <a href="#fnref:vaerezende" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:cae" role="doc-endnote">
<p>Rifai, Vincent, Muller, Glorot and Bengio, “<a href="https://openreview.net/forum?id=HkZN5j-dZH">Contractive Auto-Encoders: Explicit Invariance During Feature Extraction</a>”, International Conference on Machine Learning, 2011. <a href="#fnref:cae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sdae" role="doc-endnote">
<p>Vincent, Larochelle, Lajoie, Bengio and Manzagol, “<a href="https://www.jmlr.org/papers/v11/vincent10a.html">Stacked Denoising Autoencoders: Learning Useful Representations in a Deep Network with a Local Denoising Criterion</a>”, Journal of Machine Learning Research, 2010. <a href="#fnref:sdae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vincent" role="doc-endnote">
<p>Vincent, “<a href="http://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf">A Connection Between Score Matching and Denoising Autoencoders</a>”, Technical report, 2010. <a href="#fnref:vincent" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:vincent:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:maxlikelihood" role="doc-endnote">
<p>Song, Durkan, Murray and Ermon, “<a href="https://arxiv.org/abs/2101.09258">Maximum Likelihood Training of Score-Based Diffusion Models</a>”, Neural Information Processing Systems, 2021. <a href="#fnref:maxlikelihood" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gyom1" role="doc-endnote">
<p>Bengio, Alain and Rifai, “<a href="https://arxiv.org/abs/1207.0057">Implicit density estimation by local moment matching to sample from auto-encoders</a>”, arXiv, 2012. <a href="#fnref:gyom1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gyom2" role="doc-endnote">
<p>Alain, Bengio and Rifai, “<a href="http://www.eng.uwaterloo.ca/~jbergstr/files/nips_dl_2012/Paper%2029.pdf">Regularized auto-encoders estimate local statistics</a>”, Neural Information Processing Systems, Deep Learning workshop, 2012. <a href="#fnref:gyom2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gyom3" role="doc-endnote">
<p>Bengio, Yao, Alain and Vincent, “<a href="https://arxiv.org/abs/1305.6663">Generalized denoising auto-encoders as generative models</a>”, Neural Information Processing Systems, 2013. <a href="#fnref:gyom3" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gyom4" role="doc-endnote">
<p>Alain and Bengio, “<a href="https://jmlr.org/papers/volume15/alain14a/alain14a.pdf">What regularized auto-encoders learn from the data-generating distribution</a>”, Journal of Machine Learning Research, 2014. <a href="#fnref:gyom4" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gyom5" role="doc-endnote">
<p>Bengio, Laufer, Alain and Yosinski, “<a href="http://proceedings.mlr.press/v32/bengio14.pdf">Deep generative stochastic networks trainable by backprop</a>”, International Conference on Machine Learning, 2014. <a href="#fnref:gyom5" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gyom6" role="doc-endnote">
<p>Alain, Bengio, Yao, Yosinski, Laufer, Zhang and Vincent, “<a href="https://arxiv.org/abs/1503.05571">GSNs: generative stochastic networks</a>”, Information and Inference, 2016. <a href="#fnref:gyom6" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:geras1" role="doc-endnote">
<p>Geras and Sutton, “<a href="https://arxiv.org/abs/1406.3269">Scheduled denoising autoencoders</a>”, International Conference on Learning Representations, 2015. <a href="#fnref:geras1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:chandra" role="doc-endnote">
<p>Chandra and Sharma, “<a href="https://link.springer.com/chapter/10.1007/978-3-319-12637-1_67">Adaptive noise schedule for denoising autoencoder</a>”, Neural Information Processing Systems, 2014. <a href="#fnref:chandra" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:zhang" role="doc-endnote">
<p>Zhang and Zhang, “<a href="https://dl.acm.org/doi/abs/10.1007/s11704-016-6107-0">Convolutional adaptive denoising autoencoders for hierarchical feature extraction</a>”, Frontiers of Computer Science, 2018. <a href="#fnref:zhang" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:geras2" role="doc-endnote">
<p>Geras and Sutton, “<a href="https://link.springer.com/chapter/10.1007/978-3-319-46128-1_43">Composite denoising autoencoders</a>”, Joint European Conference on Machine Learning and Knowledge Discovery in Databases, 2016. <a href="#fnref:geras2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Diffusion models took off like a rocket at the end of 2019, after the publication of Song &amp; Ermon’s seminal paper. In this blog post, I highlight a connection to another type of model: the venerable autoencoder.Addendum: quantifying our flawed intuitions2020-09-01T00:00:00+01:00https://sander.ai/2020/09/01/typicality-addendum<p>This post is an addendum to <a href="/2020/09/01/typicality.html">my blog post about typicality</a>. Please consider reading that first, if you haven’t already. Here, I will try to quantify what happens when our intuitions fail us in high-dimensional spaces.</p>
<p><em>Note that the practical relevance of this is limited, so consider this a piece of optional extra content!</em></p>
<p>In the ‘unfair coin flips’ example from the main blog post, it’s actually pretty clear what happens when our intuitions fail us: we think of the binomial distribution, <strong>ignoring the order of the sequences as a factor, when we should actually be taking it into account</strong>. Referring back to the table from section 2.1, we use the probabilities in the rightmost column, when we should be using those in the third column. But when we think of a high-dimensional Gaussian distribution and come to the wrong conclusion, what distribution are we <em>actually</em> thinking of?</p>
<h2 id="the-gaussian-distribution-mathcaln_k">The Gaussian distribution \(\mathcal{N}_K\)</h2>
<figure>
<img src="/images/bubbles.jpg" />
</figure>
<p>Let’s start by quantifying what a multivariate Gaussian distribution actually looks like: let \(\mathbf{x} \sim \mathcal{N}(\mathbf{0}, I_K)\), a standard Gaussian distribution in \(K\) dimensions, henceforth referred to as \(\mathcal{N}_K\). We can sample from it by drawing \(K\) independent one-dimensional samples \(x_i \sim \mathcal{N}(0, 1)\), and joining them into a vector \(\mathbf{x}\). This distribution is <strong>spherically symmetric</strong>, which makes it very natural to think about samples in terms of their <strong>distance to the mode</strong> (in this case, the origin, corresponding to the zero-vector \(\mathbf{0}\)), because all samples at a given distance \(r\) have the same density.</p>
<p>Now, let’s look at the distribution of \(r\): it seems as if the multivariate Gaussian distribution \(\mathcal{N}_K\) naturally arises by taking a univariate version of it, and rotating it around the mode in every possible direction in \(K\)-dimensional space. Because each of these individual rotated copies is Gaussian, this in turn might seem to imply that the distance from the mode \(r\) is itself Gaussian (or rather half-Gaussian, since it is a nonnegative quantity). But this is incorrect! \(r\) actually follows a <a href="https://en.wikipedia.org/wiki/Chi_distribution"><strong>chi distribution</strong></a> with \(K\) degrees of freedom: \(r \sim \chi_K\).</p>
<p>Note that for \(K = 1\), this does indeed correspond to a half-Gaussian distribution. But as \(K\) increases, the mode of the chi distribution rapidly shifts away from 0: it actually sits at \(\sqrt{K - 1}\). This leaves considerably less probability mass near 0, where the mode of our original multivariate Gaussian \(\mathcal{N}_K\) is located.</p>
<p>This exercise yields an alternative sampling strategy for multivariate Gaussians: first, sample a distance from the mode \(r \sim \chi_K\). Then, sample a direction, i.e. a vector on the \(K\)-dimensional unit sphere \(S^K\), uniformly at random: \(\mathbf{\theta} \sim U[S^K]\). Multiply them together to obtain a Gaussian sample: \(\mathbf{x} = r \cdot \mathbf{\theta} \sim \mathcal{N}_K\).</p>
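<p>This equivalence is easy to check empirically (a quick sketch): samples produced via the radius-times-direction route should be statistically indistinguishable from direct Gaussian samples, and in both cases the radii concentrate around \(\sqrt{K}\) rather than near the mode.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
K, N = 100, 100000

# Route 1: direct multivariate standard Gaussian samples.
x_direct = rng.standard_normal((N, K))

# Route 2: radius r ~ chi_K (the square root of a chi-squared draw),
# direction theta uniform on the unit sphere.
r = np.sqrt(rng.chisquare(K, size=(N, 1)))
theta = rng.standard_normal((N, K))
theta /= np.linalg.norm(theta, axis=1, keepdims=True)
x_twostep = r * theta

# Coordinate-wise statistics match, and radii sit near sqrt(K), not near 0.
print(float(x_direct.std()), float(x_twostep.std()))
print(float(np.linalg.norm(x_direct, axis=1).mean()), np.sqrt(K))
```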
<h2 id="the-gaussian-mirage-distribution-mathcalm_k">The Gaussian mirage distribution \(\mathcal{M}_K\)</h2>
<figure>
<img src="/images/mirage.jpg" />
</figure>
<p>What if, instead of sampling \(r \sim \chi_K\), we sampled \(r \sim \mathcal{N}(0, K)\) instead? Note that \(\sigma^2_{\chi_K} = K\), so this change preserves the scale of the resulting vectors. For \(K = 1\), we get the same distribution for \(\mathbf{x}\), but for \(K > 1\), we get something very different. The resulting distribution represents what we might think the multivariate Gaussian distribution looks like, if we rely on a mistaken intuition and squint a bit. Let’s call this the <strong>Gaussian mirage</strong> distribution, denoted by \(\mathcal{M}\): \(\mathbf{x} = r \cdot \mathbf{\theta} \sim \mathcal{M}_K\). (If this thing already has a name, I’m not aware of it, so please let me know!)</p>
<p>We’ve already established that \(\mathcal{M}_1 \equiv \mathcal{N}_1\). But in higher dimensions, these distributions behave very differently. One way to comprehend this is to look at a flattened histogram of samples across all coordinates:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="k">def</span> <span class="nf">gaussian</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">mirage</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">):</span>
    <span class="n">direction</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">k</span><span class="p">))</span>
    <span class="n">direction</span> <span class="o">/=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">direction</span><span class="o">**</span><span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>
    <span class="n">distance</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">k</span><span class="p">),</span> <span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">distance</span> <span class="o">*</span> <span class="n">direction</span>
<span class="k">def</span> <span class="nf">plot_histogram</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">hist</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">ravel</span><span class="p">(),</span> <span class="n">bins</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">ylim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">80000</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">xlim</span><span class="p">(</span><span class="o">-</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">tick_params</span><span class="p">(</span><span class="n">labelleft</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">left</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">labelbottom</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">bottom</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ks</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">100</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">ks</span><span class="p">):</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">ks</span><span class="p">),</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="sa">f</span><span class="s">'K = </span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>
    <span class="n">plot_histogram</span><span class="p">(</span><span class="n">gaussian</span><span class="p">(</span><span class="mi">10</span><span class="o">**</span><span class="mi">6</span> <span class="o">//</span> <span class="n">k</span><span class="p">,</span> <span class="n">k</span><span class="p">))</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">ks</span><span class="p">),</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">+</span> <span class="nb">len</span><span class="p">(</span><span class="n">ks</span><span class="p">))</span>
    <span class="n">plot_histogram</span><span class="p">(</span><span class="n">mirage</span><span class="p">(</span><span class="mi">10</span><span class="o">**</span><span class="mi">6</span> <span class="o">//</span> <span class="n">k</span><span class="p">,</span> <span class="n">k</span><span class="p">))</span>
</code></pre></div></div>
<figure>
<a href="/images/gaussian_histograms.png"><img src="/images/gaussian_histograms.png" alt="Histograms of the flattened coordinates of the multivariate Gaussian distribution (top) and the Gaussian mirage (bottom)." /></a>
<figcaption>Histograms of the flattened coordinates of the multivariate Gaussian distribution (top) and the Gaussian mirage (bottom), for different dimensionalities (K). For the mirage, the histograms become increasingly peaked around 0 as the dimensionality increases.</figcaption>
</figure>
<p>For \(\mathcal{N}_K\), this predictably looks like a univariate Gaussian for all \(K\). For \(\mathcal{M}_K\), it becomes highly <a href="https://en.wikipedia.org/wiki/Kurtosis">leptokurtic</a> as \(K\) increases, indicating that <strong>dramatically more probability mass is located close to the mode</strong>.</p>
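<p>This can be quantified with the excess kurtosis of the flattened coordinates, which is \(0\) for a Gaussian and increasingly positive for the mirage as \(K\) grows. A quick sketch (the <code>mirage</code> sampler is re-implemented inline so the snippet stands alone, and <code>scipy.stats.kurtosis</code> is assumed available):</p>

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(0)

def mirage(n, k):
    # Same construction as above: uniform direction, Gaussian-distributed radius.
    direction = rng.standard_normal((n, k))
    direction /= np.linalg.norm(direction, axis=1, keepdims=True)
    distance = rng.normal(0, np.sqrt(k), (n, 1))
    return distance * direction

kurt = {}
for k in [1, 3, 10, 100]:
    # Excess kurtosis of the flattened coordinates: 0 for a Gaussian,
    # increasingly positive (sharply peaked around 0) for the mirage.
    kurt[k] = float(scipy.stats.kurtosis(mirage(10**6 // k, k).ravel()))
    print(k, round(kurt[k], 2))
```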
<h2 id="typical-sets-of-mathcaln_k-and-mathcalm_k">Typical sets of \(\mathcal{N}_K\) and \(\mathcal{M}_K\)</h2>
<p>Let’s also look at the typical sets for both of these distributions. For \(\mathcal{N}_K\), the probability density function (pdf) has the form:</p>
\[f_{\mathcal{N}_K}(\mathbf{x}) = (2 \pi)^{-\frac{K}{2}} \exp \left( -\frac{\mathbf{x}^T \mathbf{x}}{2} \right),\]
<p>and the differential entropy is given by:</p>
\[H_{\mathcal{N}_K} = \frac{K}{2} \log \left(2 \pi e \right) .\]
<p>To find the typical set, we just need to look for the \(\mathbf{x}\) where \(f_{\mathcal{N}_K}(\mathbf{x}) \approx 2^{-H_{\mathcal{N}_K}} = (2 \pi e)^{-\frac{K}{2}}\) (assuming the entropy is measured in bits). This is clearly the case when \(\mathbf{x}^T\mathbf{x} \approx K\), or in other words, for <strong>any \(\mathbf{x}\) whose distance from the mode is close to \(\sqrt{K}\)</strong>. This is the <em>Gaussian annulus</em> from before.</p>
<p>Let’s subject the Gaussian mirage \(\mathcal{M}_K\) to the same treatment. It’s not obvious how to express the pdf in terms of \(\mathbf{x}\), but it’s easier if we rewrite \(\mathbf{x}\) as \(r \cdot \mathbf{\theta}\), as before, and imagine the sampling procedure: first, pick a radius \(r \sim \mathcal{HN}(0, K)\) (the half-Gaussian distribution — using the Gaussian distribution complicates the math a bit, because the radius should be nonnegative), and then pick a position on the \(K\)-sphere with radius \(r\), uniformly at random:</p>
\[f_{\mathcal{M}_K}(\mathbf{x}) = f_{\mathcal{HN}(0, K)}(r) \cdot f_{U[S^K(r)]}(\theta) = \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \cdot \frac{1}{r^{K-1}} \frac{\Gamma\left( \frac{K}{2} \right)}{2 \pi ^ \frac{K}{2}} .\]
<p>The former factor is the density of the half-Gaussian distribution: note the additional factor 2 compared to the standard Gaussian density, because we only consider nonnegative values of \(r\). The latter is the density of a uniform distribution on the \(K\)-sphere with radius \(r\) (which is the inverse of its surface area). As an aside, this factor is worth taking a closer look at, because it behaves in a rather peculiar way. Here’s the surface area of a unit \(K\)-sphere for increasing \(K\):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">scipy.special</span>
<span class="n">K</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">30</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span><span class="o">**</span><span class="p">(</span><span class="n">K</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">))</span> <span class="o">/</span> <span class="n">scipy</span><span class="p">.</span><span class="n">special</span><span class="p">.</span><span class="n">gamma</span><span class="p">(</span><span class="n">K</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">stem</span><span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">basefmt</span><span class="o">=</span><span class="s">' '</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">35</span><span class="p">)</span>
</code></pre></div></div>
<figure>
<a href="/images/sphere_area.png"><img src="/images/sphere_area.png" alt="Surface area of a K-dimensional unit sphere, for K ranging from 0 to 30." /></a>
<figcaption>Surface area of a K-dimensional unit sphere, for K ranging from 0 to 30.</figcaption>
</figure>
<p>Confused? You and me both! Believe it or not, <strong>the surface area of a \(K\)-sphere tends to zero with increasing \(K\)</strong> — but only after growing to a maximum at \(K = 7\) first. <a href="https://math.stackexchange.com/questions/67039/why-does-the-volume-of-the-unit-sphere-go-to-zero">High-dimensional spaces are <em>weird</em></a>.</p>
<p>Another thing worth noting is that the density at the mode \(f_{\mathcal{M}_K}(\mathbf{0}) = +\infty\) for \(K > 1\), which already suggests that this distribution has a lot of its mass concentrated near the mode.</p>
<p>Computing the entropy of this distribution takes a bit of work. The differential entropy is:</p>
\[H_{\mathcal{M}_K} = - \int_{\mathbb{R}^K} f_{\mathcal{M}_K}(\mathbf{x}) \log f_{\mathcal{M}_K}(\mathbf{x}) \mathrm{d}\mathbf{x} .\]
<p>We can use the radial symmetry of this density to reformulate this as an integral of a scalar function:</p>
\[H_{\mathcal{M}_K} = - \int_0^{+\infty} f_{\mathcal{M}_K}(r) \log f_{\mathcal{M}_K}(r) S^K(r) \mathrm{d} r,\]
<p>where \(S^K(r)\) is the surface area of a \(K\)-sphere with radius \(r\). Filling in the density function, we get:</p>
\[H_{\mathcal{M}_K} = - \int_0^{+\infty} \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \cdot \log \left( \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \cdot \frac{1}{r^{K-1}} \frac{\Gamma\left( \frac{K}{2} \right)}{2 \pi ^ \frac{K}{2}} \right) \mathrm{d} r,\]
<p>where we have made use of the fact that \(S^K(r)\) cancels out with the second factor of \(f_{\mathcal{M}_K}(r)\). We can split up the \(\log\) into three different terms, \(H_{\mathcal{M}_K} = H_1 + H_2 + H_3\):</p>
\[H_1 = - \int_0^{+\infty} \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \left(-\frac{r^2}{2 K} \right) \mathrm{d} r = \int_0^{+\infty} \frac{r^2}{\sqrt{2 \pi}} \exp \left( -\frac{r^2}{2} \right) \mathrm{d} r = \frac{1}{2},\]
\[H_2 = - \int_0^{+\infty} \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \log \left( \frac{1}{r^{K-1}} \right) \mathrm{d} r = \frac{K - 1}{2} \left( \log \frac{K}{2} - \gamma \right),\]
\[H_3 = - \int_0^{+\infty} \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \log \left( \frac{2}{\sqrt{2 \pi K}} \frac{\Gamma\left( \frac{K}{2} \right)}{2 \pi ^ \frac{K}{2}} \right) \mathrm{d} r = - \log \left( \frac{1}{\sqrt{2 \pi K}} \frac{\Gamma\left( \frac{K}{2} \right)}{\pi ^ \frac{K}{2}} \right),\]
<p>where we have taken \(\log\) to be the natural logarithm for convenience, and \(\gamma\) is the <a href="https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant">Euler-Mascheroni constant</a>. In summary:</p>
\[H_{\mathcal{M}_K} = \frac{1}{2} + \frac{K - 1}{2} \left( \log \frac{K}{2} - \gamma \right) - \log \left( \frac{1}{\sqrt{2 \pi K}} \frac{\Gamma\left( \frac{K}{2} \right)}{\pi ^ \frac{K}{2}} \right) .\]
<p>Note that \(H_{\mathcal{M}_1} = \frac{1}{2} \log (2 \pi e)\), matching the standard Gaussian distribution as expected.</p>
<p>Because this is measured in nats, not in bits, we find the typical set where \(f_{\mathcal{M}_K}(\mathbf{x}) \approx \exp(-H_{\mathcal{M}_K})\). We must find \(r \geq 0\) so that</p>
\[\frac{r^2}{2 K} + (K - 1) \log r = \frac{1}{2} + \frac{K - 1}{2} \left( \log \frac{K}{2} - \gamma \right) .\]
<p>We can express the solution of this equation in terms of the Lambert \(W\) function, by substituting \(t = r^2\) and rearranging into the form \(u e^u = c\):</p>
\[r = \sqrt{K (K - 1) W\left(\frac{1}{K (K - 1)} \exp \left( \frac{1}{K - 1} + \log \frac{K}{2} - \gamma \right) \right)} .\]
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">scipy.special</span>
<span class="n">K</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">unique</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">round</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">logspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">100</span><span class="p">)))</span>
<span class="n">w_arg</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">K</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">K</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="n">euler_gamma</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">K</span> <span class="o">*</span> <span class="p">(</span><span class="n">K</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">r</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">K</span> <span class="o">*</span> <span class="p">(</span><span class="n">K</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">scipy</span><span class="p">.</span><span class="n">special</span><span class="p">.</span><span class="n">lambertw</span><span class="p">(</span><span class="n">w_arg</span><span class="p">))</span>
<span class="n">r</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># Special case for K = 1.
</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="n">r</span> <span class="o">/</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">K</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xscale</span><span class="p">(</span><span class="s">'log'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'$K$'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'$</span><span class="se">\\</span><span class="s">frac{r}{</span><span class="se">\\</span><span class="s">sqrt{K}}$'</span><span class="p">)</span>
</code></pre></div></div>
<figure>
<a href="/images/mirage_radius.png"><img src="/images/mirage_radius.png" alt="The distance from the mode at which the typical set of the Gaussian mirage is found, as a function of K." /></a>
<figcaption>The distance from the mode at which the typical set of the Gaussian mirage is found, normalised by the standard deviation, as a function of K.</figcaption>
</figure>
<p>As \(K \to +\infty\), this seems to converge to the value \(0.52984 \sqrt{K}\), which is somewhere in between the mode (\(0\)) and the mean (\(\sqrt{\frac{2K}{\pi}} \approx 0.79788 \sqrt{K}\)) of the half-Gaussian distribution (which \(r\) follows by construction). This is not just an interesting curiosity: although it is clear that the typical set of \(\mathcal{M}_K\) is much closer to the mode than for \(\mathcal{N}_K\) (because \(r < \sqrt{K}\)), the mode is not unequivocally a member of the typical set. In fact, the definition of typical sets sort of breaks down for this distribution, because we need to allow for a very large range of probability densities to capture the bulk of its mass. In this sense, it behaves a lot more like the one-dimensional Gaussian. Nevertheless, even this strange concoction of a distribution exhibits unintuitive behaviour in high-dimensional space!</p>
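<p>The limiting value can also be derived directly (a quick sketch of the argument, not from the original derivation): for large \(K\), the argument of the Lambert W function above tends to \(e^{-\gamma} / (2K)\), and \(W(x) \approx x\) for small \(x\), which gives \(\frac{r}{\sqrt{K}} \to \sqrt{\frac{e^{-\gamma}}{2}} \approx 0.52984\). This is easy to check numerically against the exact expression:</p>

```python
import numpy as np
import scipy.special

# Claimed limit: for large K, the Lambert W argument tends to
# exp(-euler_gamma) / (2K), and W(x) ~ x for small x, so
# r / sqrt(K) -> sqrt(exp(-euler_gamma) / 2).
limit = np.sqrt(np.exp(-np.euler_gamma) / 2)

# Evaluate the exact expression from the snippet above for a large K.
K = 1e9
w_arg = np.exp(1 / (K - 1) + np.log(K / 2) - np.euler_gamma) / (K * (K - 1))
r = np.sqrt(K * (K - 1) * scipy.special.lambertw(w_arg)).real

print(limit, r / np.sqrt(K))  # both ~0.52984
```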
<p><em>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</em></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@misc{dieleman2020typicality,
author = {Dieleman, Sander},
title = {Musings on typicality},
url = {https://benanne.github.io/2020/09/01/typicality.html},
year = {2020}
}
</code></pre></div></div>This post is an addendum to my blog post about typicality. Please consider reading that first, if you haven’t already. Here, I will try to quantify what happens when our intuitions fail us in high-dimensional spaces.Musings on typicality2020-09-01T00:00:00+01:002020-09-01T00:00:00+01:00https://sander.ai/2020/09/01/typicality<p>If you’re training or sampling from generative models, <strong>typicality</strong> is a concept worth understanding. It sheds light on why beam search doesn’t work for autoregressive models of images, audio and video; why you can’t just threshold the likelihood to perform anomaly detection with generative models; and why high-dimensional Gaussians are “soap bubbles”. This post is a summary of my current thoughts on the topic.</p>
<p>First, some context: one of the reasons I’m writing this, is to structure my own thoughts about typicality and the unintuitive behaviour of high-dimensional probability distributions. Most of these thoughts have not been empirically validated, and several are <strong>highly speculative</strong> and could be wrong. Please bear this in mind when reading, and don’t hesitate to use the comments section to correct me. Another reason is to draw more attention to the concept, as I’ve personally found it extremely useful to gain insight into the behaviour of generative models, and to correct some of my flawed intuitions. I <a href="https://twitter.com/sedielem/status/1264587646321516544">tweeted</a> about typicality a few months ago, but as it turns out, I have a lot more to say on the topic!</p>
<p>As with most of my blog posts, I will assume a degree of familiarity with machine learning. For certain parts, some knowledge of generative modelling is probably useful as well. <a href="https://benanne.github.io/2020/03/24/audio-generation.html#generative-models">Section 3 of my previous blog post</a> provides an overview of generative models.</p>
<p><strong>Overview</strong> (click to scroll to each section):</p>
<ol>
<li><em><a href="#likelihood">The joys of likelihood</a></em></li>
<li><em><a href="#examples">Motivating examples</a></em></li>
<li><em><a href="#abstraction">Abstraction and the curse of dimensionality</a></em></li>
<li><em><a href="#typicality">Typicality</a></em></li>
<li><em><a href="#in-the-wild">Typicality in the wild</a></em></li>
<li><em><a href="#right-level">The right level of abstraction</a></em></li>
<li><em><a href="#closing-thoughts">Closing thoughts</a></em></li>
<li><em><a href="#acknowledgements">Acknowledgements</a></em></li>
<li><em><a href="#references">References</a></em></li>
</ol>
<h2 id="-the-joys-of-likelihood"><a name="likelihood"></a> The joys of likelihood</h2>
<p>When it comes to generative modelling, my personal preference for the <strong>likelihood-based paradigm</strong> is no secret (my recent foray into <a href="https://www.deepmind.com/publications/end-to-end-adversarial-text-to-speech">adversarial methods for text-to-speech</a> notwithstanding). While there are many other ways to build and train models (e.g. using adversarial networks, score matching, optimal transport, quantile regression, … see <a href="https://benanne.github.io/2020/03/24/audio-generation.html#generative-models">my previous blog post</a> for an overview), there is something intellectually pleasing about the simplicity of maximum likelihood training: the model explicitly parameterises a probability distribution, and we fit the parameters of that distribution so it is able to explain the observed data as well as possible (i.e., assigns to it the highest possible likelihood).</p>
<p>It turns out that this is far from the whole story, and <strong>‘<em>higher likelihood</em>’ doesn’t always mean <em>better</em> in a way that we actually care about</strong>. In fact, the way likelihood behaves in relation to the quality of a model as measured by humans (e.g. by inspecting samples) can be deeply unintuitive. This has been well-known in the machine learning community for some time, and Theis et al.’s <a href="https://arxiv.org/abs/1511.01844"><em>A note on the evaluation of generative models</em></a><sup id="fnref:anote" role="doc-noteref"><a href="#fn:anote" class="footnote" rel="footnote">1</a></sup> does an excellent job of demonstrating this with clever thought experiments and concrete examples. In what follows, I will expound on what I think is going on when likelihoods disagree with our intuitions.</p>
<p>One particular way in which a higher likelihood can correspond to a worse model is through <strong>overfitting</strong> on the training set. Because overfitting is ubiquitous in machine learning research, the unintuitive behaviours of likelihood are often incorrectly ascribed to this phenomenon. In this post, I will assume that overfitting is not an issue, and that we are talking about properly regularised models trained on large enough datasets.</p>
<h2 id="-motivating-examples"><a name="examples"></a> Motivating examples</h2>
<h3 id="unfair-coin-flips">Unfair coin flips</h3>
<figure>
<img src="/images/coins.jpg" />
</figure>
<p><a href="https://www.jessicayung.com/counterintuitive-probabilities-typical-sets-from-information-theory/">Jessica Yung has a great blog post</a> that demonstrates how even the simplest of probability distributions start behaving in unintuitive ways in higher-dimensional spaces, and she links this to the concept of typicality. I will borrow her example here and expand on it a bit, but I recommend reading the original post.</p>
<p>To summarise: suppose you have an unfair coin that lands on heads 3 times out of 4. If you toss this coin 16 times, you would expect to see 12 heads (<code class="language-plaintext highlighter-rouge">H</code>) and 4 tails (<code class="language-plaintext highlighter-rouge">T</code>) on average. Of course you wouldn’t expect to see exactly 12 heads and 4 tails every time: there’s a pretty good chance you’d see 13 heads and 3 tails, or 11 heads and 5 tails. Seeing 16 heads and no tails would be quite surprising, but it’s not implausible: in fact, it will happen about 1% of the time. Seeing all tails seems like it would be a miracle. Nevertheless, each coin toss is independent, so even this has a non-zero probability of being observed.</p>
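<p>These probabilities are easy to verify (a quick sanity check of the numbers above):</p>

```python
# Probability of observing all heads, and all tails, in 16 tosses
# of a coin that lands on heads 3 times out of 4.
p_all_heads = (3 / 4) ** 16
p_all_tails = (1 / 4) ** 16

print(p_all_heads)  # ~0.01: all heads happens about 1% of the time
print(p_all_tails)  # ~2.3e-10: non-zero, but miraculous
```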
<p>When we count the number of heads and tails in the observed sequence, we’re looking at the <strong><a href="https://en.wikipedia.org/wiki/Binomial_distribution">binomial distribution</a></strong>. We’ve made the implicit assumption that what we care about is the <strong>frequency of occurrence of both outcomes, and not the order in which they occur</strong>. We’ve made <em>abstraction</em> of the order, and we are effectively treating the sequences as unordered sets, so that <code class="language-plaintext highlighter-rouge">HTHHTHHHHTTHHHHH</code> and <code class="language-plaintext highlighter-rouge">HHHHHTHTHHHTHTHH</code> are basically the same thing. That is often desirable, but it’s worth being aware of such assumptions, and making them explicit.</p>
<p><strong>If we do not ignore the order, and ask which sequence is the most likely, the answer is ‘all heads’.</strong> That may seem surprising at first, because seeing only heads is a relatively rare occurrence. But note that we’re asking a different question here, about the ordered sequences themselves, rather than about their statistics. While the difference is pretty clear here, the implicit assumptions and abstractions that we tend to use in our reasoning are often more subtle.</p>
<p>The table and figure below show how the probability of observing a given number of heads and tails can be found by multiplying the probability of a particular sequence with the number of such sequences. Note that ‘all heads’ has the highest probability out of all sequences (bolded), but there is only a single such sequence. The most likely number of heads we’ll observe is 12 (also bolded): even though each individual sequence with 12 heads is less likely, there are a lot more of them, and this second factor ends up dominating.</p>
<table>
<thead>
<tr>
<th style="text-align: center">#H</th>
<th style="text-align: center">#T</th>
<th style="text-align: center">p(sequence)</th>
<th style="text-align: center"># sequences</th>
<th style="text-align: center">p(#H, #T)</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: center">0</td>
<td style="text-align: center">16</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^0 \left(\frac{1}{4}\right)^{16} = 2.33 \cdot 10^{-10}\)</td>
<td style="text-align: center">1</td>
<td style="text-align: center">\(2.33\cdot 10^{-10}\)</td>
</tr>
<tr>
<td style="text-align: center">1</td>
<td style="text-align: center">15</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^1 \left(\frac{1}{4}\right)^{15} = 6.98 \cdot 10^{-10}\)</td>
<td style="text-align: center">16</td>
<td style="text-align: center">\(1.12\cdot 10^{-8}\)</td>
</tr>
<tr>
<td style="text-align: center">2</td>
<td style="text-align: center">14</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^2 \left(\frac{1}{4}\right)^{14} = 2.10 \cdot 10^{-9}\)</td>
<td style="text-align: center">120</td>
<td style="text-align: center">\(2.51\cdot 10^{-7}\)</td>
</tr>
<tr>
<td style="text-align: center">3</td>
<td style="text-align: center">13</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^3 \left(\frac{1}{4}\right)^{13} = 6.29 \cdot 10^{-9}\)</td>
<td style="text-align: center">560</td>
<td style="text-align: center">\(3.52\cdot 10^{-6}\)</td>
</tr>
<tr>
<td style="text-align: center">4</td>
<td style="text-align: center">12</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^4 \left(\frac{1}{4}\right)^{12} = 1.89 \cdot 10^{-8}\)</td>
<td style="text-align: center">1820</td>
<td style="text-align: center">\(3.43\cdot 10^{-5}\)</td>
</tr>
<tr>
<td style="text-align: center">5</td>
<td style="text-align: center">11</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^5 \left(\frac{1}{4}\right)^{11} = 5.66 \cdot 10^{-8}\)</td>
<td style="text-align: center">4368</td>
<td style="text-align: center">\(2.47\cdot 10^{-4}\)</td>
</tr>
<tr>
<td style="text-align: center">6</td>
<td style="text-align: center">10</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^6 \left(\frac{1}{4}\right)^{10} = 1.70 \cdot 10^{-7}\)</td>
<td style="text-align: center">8008</td>
<td style="text-align: center">\(1.36\cdot 10^{-3}\)</td>
</tr>
<tr>
<td style="text-align: center">7</td>
<td style="text-align: center">9</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^7 \left(\frac{1}{4}\right)^9 = 5.09 \cdot 10^{-7}\)</td>
<td style="text-align: center">11440</td>
<td style="text-align: center">\(5.83\cdot 10^{-3}\)</td>
</tr>
<tr>
<td style="text-align: center">8</td>
<td style="text-align: center">8</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^8 \left(\frac{1}{4}\right)^8 = 1.53 \cdot 10^{-6}\)</td>
<td style="text-align: center">12870</td>
<td style="text-align: center">\(1.97\cdot 10^{-2}\)</td>
</tr>
<tr>
<td style="text-align: center">9</td>
<td style="text-align: center">7</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^9 \left(\frac{1}{4}\right)^7 = 4.58 \cdot 10^{-6}\)</td>
<td style="text-align: center">11440</td>
<td style="text-align: center">\(5.24\cdot 10^{-2}\)</td>
</tr>
<tr>
<td style="text-align: center">10</td>
<td style="text-align: center">6</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{10} \left(\frac{1}{4}\right)^6 = 1.37 \cdot 10^{-5}\)</td>
<td style="text-align: center">8008</td>
<td style="text-align: center">\(1.10\cdot 10^{-1}\)</td>
</tr>
<tr>
<td style="text-align: center">11</td>
<td style="text-align: center">5</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{11} \left(\frac{1}{4}\right)^5 = 4.12 \cdot 10^{-5}\)</td>
<td style="text-align: center">4368</td>
<td style="text-align: center">\(1.80\cdot 10^{-1}\)</td>
</tr>
<tr>
<td style="text-align: center">12</td>
<td style="text-align: center">4</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{12} \left(\frac{1}{4}\right)^4 = 1.24 \cdot 10^{-4}\)</td>
<td style="text-align: center">1820</td>
<td style="text-align: center">\(\mathbf{2.25\cdot 10^{-1}}\)</td>
</tr>
<tr>
<td style="text-align: center">13</td>
<td style="text-align: center">3</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{13} \left(\frac{1}{4}\right)^3 = 3.71 \cdot 10^{-4}\)</td>
<td style="text-align: center">560</td>
<td style="text-align: center">\(2.08\cdot 10^{-1}\)</td>
</tr>
<tr>
<td style="text-align: center">14</td>
<td style="text-align: center">2</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{14} \left(\frac{1}{4}\right)^2 = 1.11 \cdot 10^{-3}\)</td>
<td style="text-align: center">120</td>
<td style="text-align: center">\(1.34\cdot 10^{-1}\)</td>
</tr>
<tr>
<td style="text-align: center">15</td>
<td style="text-align: center">1</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{15} \left(\frac{1}{4}\right)^1 = 3.33 \cdot 10^{-3}\)</td>
<td style="text-align: center">16</td>
<td style="text-align: center">\(5.35\cdot 10^{-2}\)</td>
</tr>
<tr>
<td style="text-align: center">16</td>
<td style="text-align: center">0</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{16} \left(\frac{1}{4}\right)^0 = \mathbf{1.00 \cdot 10^{-2}}\)</td>
<td style="text-align: center">1</td>
<td style="text-align: center">\(1.00\cdot 10^{-2}\)</td>
</tr>
</tbody>
</table>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">scipy.special</span>
<span class="n">h</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">16</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">p_sequence</span> <span class="o">=</span> <span class="p">(</span><span class="mi">3</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">**</span><span class="n">h</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">**</span><span class="p">(</span><span class="mi">16</span> <span class="o">-</span> <span class="n">h</span><span class="p">)</span>
<span class="n">num_sequences</span> <span class="o">=</span> <span class="n">scipy</span><span class="p">.</span><span class="n">special</span><span class="p">.</span><span class="n">comb</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
<span class="n">p_heads_count</span> <span class="o">=</span> <span class="n">p_sequence</span> <span class="o">*</span> <span class="n">num_sequences</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">p_sequence</span><span class="p">,</span> <span class="s">'C0-s'</span><span class="p">,</span>
<span class="n">label</span><span class="o">=</span><span class="s">'probability of a single sequence with this number of heads'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">p_heads_count</span><span class="p">,</span> <span class="s">'C1-o'</span><span class="p">,</span>
<span class="n">label</span><span class="o">=</span><span class="s">'probability of observing this number of heads'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">'log'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'number of heads'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'probability'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
</code></pre></div></div>
<figure>
<a href="/images/unfair_coin_probs.png"><img src="/images/unfair_coin_probs.png" alt="Probabilities of observing a particular sequence with a given number of heads, and of observing a given number of heads." /></a>
<figcaption>Probabilities of observing a particular sequence with a given number of heads, and of observing a given number of heads.</figcaption>
</figure>
<h3 id="gaussian-soap-bubbles">Gaussian soap bubbles</h3>
<figure>
<img src="/images/bubbles.jpg" />
</figure>
<p>Another excellent blog post about the unintuitive behaviour of high-dimensional probability distributions is Ferenc Huszar’s <a href="https://www.inference.vc/high-dimensional-gaussian-distributions-are-soap-bubble/">‘Gaussian Distributions are Soap Bubbles’</a>. A one-dimensional Gaussian looks like a bell curve: a big bump around the mode, with a tail on either side. Clearly, the bulk of the total probability mass is clumped together around the mode. In higher-dimensional spaces, this shape changes completely: the bulk of the probability mass of a spherical Gaussian distribution with unit variance in \(K\) dimensions is <strong>concentrated in a thin ‘shell’ at radius \(\sqrt{K}\)</strong>. This is known as the <em>Gaussian annulus theorem</em>.</p>
<p>For example, if we sample lots of vectors from a 100-dimensional standard Gaussian, and measure their radii, we will find that just over 84% of them are between 9 and 11, and more than 99% are between 8 and 12. Only about 0.2% have a radius smaller than 8!</p>
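<p>These percentages are straightforward to reproduce empirically (a sketch; the exact fractions will vary slightly with the random seed and sample size):</p>

```python
import numpy as np

# Sample 100,000 vectors from a 100-dimensional standard Gaussian
# and measure how their radii are distributed.
rng = np.random.default_rng(0)
radii = np.linalg.norm(rng.standard_normal((100000, 100)), axis=-1)

print(np.mean((radii > 9) & (radii < 11)))  # just over 84%
print(np.mean((radii > 8) & (radii < 12)))  # more than 99%
print(np.mean(radii < 8))                   # only about 0.2%
```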
<p>Ferenc points out an interesting implication: <strong>high-dimensional Gaussians are very similar to uniform distributions on the sphere</strong>. This clearly isn’t true for the one-dimensional case, but it turns out that’s an exception, not the rule. Stefan Stein also discusses this implication in more detail in <a href="https://stefan-stein.github.io/posts/2020-03-07-concentration-properties-of-high-dimensional-normal-distributions/">a recent blog post</a>.</p>
<p>Where our intuition can go wrong here, is that we might underestimate how quickly a high-dimensional space grows in size as we move further away from the mode. Because of the radial symmetry of the distribution, we tend to think of all points at a given distance from the mode as similar, and we implicitly group them into sets of concentric spheres. This allows us to revert back to reasoning in one dimension, which we are more comfortable with: we think of a high-dimensional Gaussian as a distribution over these sets, rather than over individual points. What we tend to overlook, is that <strong>those sets differ wildly in size</strong>: as we move away from the mode, they grow larger very quickly. Note that this does not happen at all in 1D!</p>
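<p>We can make this growth concrete: the radial density is the product of the Gaussian density, which decays as we move away from the mode, and the surface area of the sphere at radius \(r\), which grows as \(r^{K-1}\). The second factor dominates until \(r \approx \sqrt{K}\), which is exactly where the ‘shell’ sits (a small sketch to illustrate, not part of the original post):</p>

```python
import numpy as np

K = 100
r = np.linspace(0.1, 15, 100000)
# log of (surface area ~ r^(K-1)) x (Gaussian density ~ exp(-r^2 / 2)):
# the competition between these two factors peaks at r = sqrt(K - 1).
log_radial_density = (K - 1) * np.log(r) - r**2 / 2

print(r[np.argmax(log_radial_density)])  # ~sqrt(K - 1), i.e. ~9.95
```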
<h2 id="-abstraction-and-the-curse-of-dimensionality"><a name="abstraction"></a> Abstraction and the curse of dimensionality</h2>
<figure>
<img src="/images/sand.jpg" />
</figure>
<p>The <a href="https://en.wikipedia.org/wiki/Curse_of_dimensionality">curse of dimensionality</a> is a catch-all term for various phenomena that appear very different and often counterintuitive in high-dimensional spaces. It is used to highlight poor scaling behaviour of ideas and algorithms, where one wouldn’t necessarily expect it. In the context of machine learning, it is usually used in a more narrow sense, to refer to the fact that models of high-dimensional data tend to require very large training datasets to be effective. But the curse of dimensionality manifests itself in many forms, and the unintuitive behaviour of high-dimensional probability distributions is just one of them.</p>
<p>In general, humans have lousy intuitions about high-dimensional spaces. But what exactly is going on when we get things wrong about high-dimensional distributions? In both of the motivating examples, the intuition breaks down in a similar way: if we’re not careful, <strong>we might implicitly reason about the probabilities of sets, rather than individual points</strong>, without taking into account their relative sizes, and arrive at the wrong answer. This means that we can encounter this issue for both discrete and continuous distributions.</p>
<p>We can generalise this idea of grouping points into sets of similar points, by thinking of it as <strong>‘abstraction’</strong>: rather than treating each point as a separate entity, we think of it as an instance of a particular <strong>concept</strong>, and ignore its idiosyncrasies. When we think of ‘sand’, we are rarely concerned about the characteristics of each individual grain. Similarly, in the ‘unfair coin flips’ example, we group sequences by their number of heads and tails, ignoring their order. In the case of the high-dimensional Gaussian, the natural grouping of points is based on their Euclidean distance from the mode. A more high-level example is that of natural images, where individual pixel values across localised regions of the image combine to form edges, textures, or even objects. There are usually many combinations of pixel values that give rise to the same texture, and we aren’t able to visually distinguish these particular instances unless we carefully study them side by side.</p>
<p>The following is perhaps a bit of an unfounded generalisation based on my own experience, but our brains seem hardwired to perform this kind of abstraction, so that we can reason about things in the familiar low-dimensional setting. It seems to happen unconsciously and continuously, and bypassing it requires a proactive approach.</p>
<h2 id="-typicality"><a name="typicality"></a> Typicality</h2>
<figure>
<img src="/images/typicality.jpg" />
</figure>
<p>Informally, <strong>typicality</strong> refers to the characteristics that samples from a distribution tend to exhibit on average (in expectation). In the ‘unfair coin flip’ example, a sequence with 12 heads and 4 tails is ‘typical’. A sequence with 6 heads and 10 tails is highly atypical. Typical sequences contain an average amount of information: they are not particularly surprising or (un)informative.</p>
<p>We can <a href="https://en.wikipedia.org/wiki/Typical_set">formalise this intuition</a> using the <a href="https://en.wikipedia.org/wiki/Entropy_(information_theory)">entropy</a> of the distribution: a <strong>typical set</strong> \(\mathcal{T}_\varepsilon \subset \mathcal{X}\) is a set of sequences from \(\mathcal{X}\) whose probability is close to \(2^{-H}\), where \(H\) is the entropy of the distribution that the sequences were drawn from, measured in bits:</p>
\[\mathcal{T}_\varepsilon = \{ \mathbf{x} \in \mathcal{X}: 2^{-(H + \varepsilon)} \leq p(\mathbf{x}) \leq 2^{-(H - \varepsilon)} \} .\]
<p>This means that the negative log likelihood of each such sequence is close to the entropy. Note that a distribution doesn’t have just one typical set: we can define many typical sets based on how close the probability of the sequences contained therein should be to \(2^{-H}\), by choosing different values of \(\varepsilon > 0\).</p>
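<p>The ‘unfair coin flips’ example from earlier makes this concrete: a sequence with the expected 12 heads and 4 tails has a negative log likelihood exactly equal to the entropy of the distribution over 16-toss sequences, so it belongs to every typical set (a quick check):</p>

```python
import numpy as np

p = 3 / 4  # probability of heads
# Entropy of a sequence of 16 independent tosses, in bits.
H = 16 * -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
# Negative log likelihood of a single sequence with 12 heads
# and 4 tails (the expected counts).
nll = -np.log2(p**12 * (1 - p)**4)

print(H, nll)  # both ~12.98 bits: this sequence is maximally typical
```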
<p>This concept was originally defined in an information-theoretic context, but I want to focus on machine learning, where I feel it is somewhat undervalued. It is often framed in terms of sequences sampled from <a href="https://en.wikipedia.org/wiki/Stationary_ergodic_process">stationary ergodic processes</a>, but it is useful more generally for distributions of any kind of high-dimensional data points, both continuous and discrete, regardless of whether we tend to think of them as sequences.</p>
<p>Why is this relevant to our discussion of abstraction and flawed human intuitions? As the dimensionality increases, the probability that any random sample from a distribution is part of a given typical set \(\mathcal{T}_\varepsilon\) tends towards 1. In other words, randomly drawn samples will almost always be ‘typical’, and <strong>the typical set covers most of the support of the distribution</strong> (this is a consequence of the so-called <a href="https://en.wikipedia.org/wiki/Asymptotic_equipartition_property">asymptotic equipartition property (AEP)</a>). This happens even when \(\varepsilon\) is relatively small, as long as the dimensionality is high enough. This is visualised for a 100-dimensional standard Gaussian distribution below (based on empirical measurements, to avoid having to calculate some <em>gnarly</em> 100D integrals).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="n">N</span> <span class="o">=</span> <span class="mi">1000000</span>
<span class="n">K</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">samples</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">))</span>
<span class="n">radii</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">samples</span><span class="o">**</span><span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
<span class="n">epsilon</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">logspace</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">200</span><span class="p">)</span>
<span class="n">lo</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">K</span> <span class="o">-</span> <span class="n">epsilon</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mi">4</span><span class="p">),</span> <span class="mi">0</span><span class="p">))</span>
<span class="n">hi</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">K</span> <span class="o">+</span> <span class="n">epsilon</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mi">4</span><span class="p">))</span>
<span class="n">radius_range</span> <span class="o">=</span> <span class="n">hi</span> <span class="o">-</span> <span class="n">lo</span>
<span class="n">mass</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">((</span><span class="n">lo</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o"><</span> <span class="n">radii</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">radii</span> <span class="o"><</span> <span class="n">hi</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">epsilon</span><span class="p">))]</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">radius_range</span><span class="p">,</span> <span class="n">mass</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'Difference between the min. and max. radii inside '</span>
<span class="s">'$</span><span class="se">\\</span><span class="s">mathcal{T}_</span><span class="se">\\</span><span class="s">varepsilon$ for given $</span><span class="se">\\</span><span class="s">varepsilon$'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'Total probability mass in $</span><span class="se">\\</span><span class="s">mathcal{T}_</span><span class="se">\\</span><span class="s">varepsilon$'</span><span class="p">)</span>
</code></pre></div></div>
<figure>
<a href="/images/annulus_prob.png"><img src="/images/annulus_prob.png" alt="The total probability mass of a range of typical sets of a 100-dimensional standard Gaussian distribution, with their size measured by the difference between the minimal and maximal radii within the set (i.e. the width of the Gaussian annulus). An annulus with width 4 already contains most of the probability mass." /></a>
<figcaption>The total probability mass of a range of typical sets of a 100-dimensional standard Gaussian distribution, with their size measured by the difference between the minimal and maximal radii within the set (i.e. the width of the Gaussian annulus). An annulus with width 4 already contains most of the probability mass.</figcaption>
</figure>
<p>But this is where it gets interesting: for unimodal high-dimensional distributions, such as the multivariate Gaussian, <strong>the mode</strong> (i.e. the most likely value) <strong>usually isn’t part of the typical set</strong>. More generally, individual samples from high-dimensional (and potentially multimodal) distributions that have an unusually high likelihood are not typical, so we wouldn’t expect to see them when sampling. This can seem paradoxical, because they are by definition very ‘likely’ samples — it’s just that there are so few of them! Think about how surprising it would be to randomly sample the zero vector (or something very close to it) from a 100-dimensional standard Gaussian distribution.</p>
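<p>This is easy to verify numerically. The sketch below (my own quick check, using NumPy; the sample count and seed are arbitrary) draws 100,000 samples from a 100-dimensional standard Gaussian and measures how far each one lies from the mode:</p>

```python
# Numerical check of the "Gaussian soap bubble": in 100 dimensions, samples
# from a standard Gaussian concentrate in a thin annulus around radius
# sqrt(100) = 10, and never land anywhere near the mode at the origin.
import numpy as np

rng = np.random.default_rng(0)
K = 100                                  # dimensionality
samples = rng.standard_normal((100_000, K))
radii = np.linalg.norm(samples, axis=1)  # distance of each sample from the mode

print(f"mean radius: {radii.mean():.2f}")  # close to sqrt(K) = 10
print(f"min radius:  {radii.min():.2f}")   # nowhere near 0: the mode is never sampled
```

<p>None of the 100,000 samples comes anywhere close to the origin: all of the probability mass lives in a thin shell around radius \(\sqrt{100} = 10\).</p>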
<p>This has some important implications: if we want to learn more about what a high-dimensional distribution looks like, <strong>studying the most likely samples is usually a bad idea</strong>. If we want to obtain a good quality sample from a distribution, subject to constraints, we should not be trying to find the single most likely one. Yet in machine learning, these are things that we do on a regular basis. In the next section, I’ll discuss a few situations where this paradox comes up in practice. For a more mathematical treatment of typicality and the curse of dimensionality, check out <a href="https://mc-stan.org/users/documentation/case-studies/curse-dims.html">this case study by Bob Carpenter</a>.</p>
<h2 id="-typicality-in-the-wild"><a name="in-the-wild"></a> Typicality in the wild</h2>
<figure>
<img src="/images/in_the_wild.jpg" />
</figure>
<p>A significant body of literature, spanning several subfields of machine learning, has sought to interpret and/or mitigate the unintuitive ways in which high-dimensional probability distributions behave. In this section, I want to highlight a few interesting papers and discuss them in relation to the concept of typicality. Note that I’ve made a selection based on what I’ve read recently, and this is not intended to be a comprehensive overview of the literature. In fact, I would appreciate pointers to other related work (papers and blog posts) that I should take a look at!</p>
<h3 id="language-modelling">Language modelling</h3>
<p>In conditional language modelling tasks, such as machine translation or image captioning, it is common to use conditional autoregressive models in combination with heuristic decoding strategies such as <a href="https://en.wikipedia.org/wiki/Beam_search">beam search</a>. The underlying idea is that we want to <strong>find the most likely sentence (i.e. the mode of the conditional distribution, ‘MAP decoding’)</strong>, but since this is intractable, we’ll settle for an approximate result instead.</p>
<p>With typicality in mind, it’s clear that this isn’t necessarily the best idea. Indeed, researchers have found that machine translation results, measured using the <a href="https://en.wikipedia.org/wiki/BLEU">BLEU metric</a>, sometimes get worse when the <em>beam width</em> is increased<sup id="fnref:sixchallenges" role="doc-noteref"><a href="#fn:sixchallenges" class="footnote" rel="footnote">2</a></sup> <sup id="fnref:analyzinguncertainty" role="doc-noteref"><a href="#fn:analyzinguncertainty" class="footnote" rel="footnote">3</a></sup>. A higher beam width gives a better, more computationally costly approximation to the mode, but not necessarily better translation results. In this case, it’s tempting to blame the metric itself, which obviously isn’t perfect, but this effect has also been observed with human ratings<sup id="fnref:tradeoff" role="doc-noteref"><a href="#fn:tradeoff" class="footnote" rel="footnote">4</a></sup>, so that cannot be the whole story.</p>
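<p>A toy example makes the disconnect between the mode and the typical set concrete. Take a "language model" that is just 100 independent Bernoulli(0.9) tokens (a contrived stand-in of my own, not a real translation model): the single most likely sequence is all 1s, yet sampled sequences almost always contain some 0s.</p>

```python
# Contrived stand-in for the MAP decoding pathology: under a model of 100
# independent Bernoulli(0.9) tokens, the mode (all 1s) is the single most
# likely sequence, but samples almost never look like it.
import numpy as np

rng = np.random.default_rng(0)
n, p = 100, 0.9
samples = rng.random((100_000, n)) < p       # 100k sampled "sentences"

mode_prob = p ** n                           # probability of the all-1s mode
mode_fraction = (samples.sum(axis=1) == n).mean()
print(f"P(mode) = {mode_prob:.2e}")          # ~2.7e-05
print(f"fraction of samples equal to the mode: {mode_fraction:.5f}")
```

<p>Every individual sampled sequence has a lower probability than the mode, yet a perfect mode-seeking decoder would return something that sampling virtually never produces.</p>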
<p>A <a href="https://arxiv.org/abs/2005.10283">recent paper by Eikema & Aziz</a><sup id="fnref:mapdecoding" role="doc-noteref"><a href="#fn:mapdecoding" class="footnote" rel="footnote">5</a></sup> provides an excellent review of recent work in this space, and makes a compelling argument for <strong>MAP decoding as the culprit behind many of the pathologies that neural machine translation systems exhibit</strong> (rather than their network architectures or training methodologies). They also propose an alternative decoding strategy called <em>‘minimum Bayes risk’ (MBR) decoding</em> that takes into account the whole distribution, rather than only the mode.</p>
<p>In unconditional language modelling, beam search hasn’t caught on, but not for want of trying! Stochasticity of the result is often desirable in this setting, and the focus has been on sampling strategies instead. In <a href="https://arxiv.org/abs/1904.09751"><em>The Curious Case of Neural Text Degeneration</em></a><sup id="fnref:degeneration" role="doc-noteref"><a href="#fn:degeneration" class="footnote" rel="footnote">6</a></sup>, Holtzman et al. observe that <strong>maximising the probability leads to poor quality results that are often repetitive</strong>. Repetitive samples may not be typical, but they have high likelihoods simply because they are more predictable.</p>
<p>They compare a few different sampling strategies that interpolate between fully random sampling and <em>greedy decoding</em> (i.e. predicting the most likely token at every step in the sequence), including the <em>nucleus sampling</em> technique which they propose. The motivation for trying to find a middle ground is that models will assign low probabilities to sequences that they haven’t seen much during training, which makes <strong>low-probability predictions inherently less reliable</strong>. Therefore, we want to avoid sampling low-probability tokens <em>to some extent</em>.</p>
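<p>For concreteness, here is a minimal sketch of the top-p (nucleus) idea in NumPy. This is not the authors' implementation; the threshold and the toy token distribution are made up for illustration.</p>

```python
# Minimal sketch of nucleus (top-p) sampling: keep the smallest set of
# tokens whose cumulative probability reaches p, renormalise, and sample
# only from that set, cutting off the unreliable low-probability tail.
import numpy as np

def nucleus_sample(probs, p=0.9, rng=np.random.default_rng()):
    order = np.argsort(probs)[::-1]               # token indices, most likely first
    cumulative = np.cumsum(probs[order])
    cutoff = np.searchsorted(cumulative, p) + 1   # size of the nucleus
    nucleus = order[:cutoff]
    nucleus_probs = probs[nucleus] / probs[nucleus].sum()
    return rng.choice(nucleus, p=nucleus_probs)

probs = np.array([0.5, 0.3, 0.15, 0.04, 0.01])    # made-up token distribution
sample = nucleus_sample(probs, p=0.9)             # only tokens 0, 1 and 2 can be drawn
```

<p>With this distribution and p = 0.9, the nucleus consists of the three most likely tokens (cumulative probability 0.95); the two rarest tokens can never be sampled.</p>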
<p><a href="https://arxiv.org/abs/2004.10450">Zhang et al.</a><sup id="fnref:tradeoff:1" role="doc-noteref"><a href="#fn:tradeoff" class="footnote" rel="footnote">4</a></sup> frame the choice of a language model decoding strategy as a trade-off between diversity and quality. However, they find that reducing diversity only helps quality up to a point, and reducing it too much makes the results worse, as judged by human evaluators. They call this <em>‘the likelihood trap’</em>: <strong>human-judged quality of samples correlates very well with likelihood, up to an inflection point, where the correlation becomes negative</strong>.</p>
<p>In the context of typicality, this raises an interesting question: where exactly is this inflection point, and how does it relate to the typical set of the model distribution? I think it would be very interesting to determine whether the inflection point coincides exactly with the typical set, or whether it is more/less likely. Perhaps there is some degree of atypicality that human raters will tolerate? If so, can we quantify it? This wouldn’t be far-fetched: think about our preference for celebrity faces over ‘typical’ human faces, for example!</p>
<h3 id="image-modelling">Image modelling</h3>
<p>The previously mentioned <em>‘note on the evaluation of generative models’</em><sup id="fnref:anote:1" role="doc-noteref"><a href="#fn:anote" class="footnote" rel="footnote">1</a></sup> is a seminal piece of work that demonstrates several ways in which likelihoods in the image domain can be vastly misleading.</p>
<p>In <a href="https://arxiv.org/abs/1810.09136"><em>‘Do Deep Generative Models Know What They Don’t Know?’</em></a><sup id="fnref:know" role="doc-noteref"><a href="#fn:know" class="footnote" rel="footnote">7</a></sup>, Nalisnick et al. study the behaviour of likelihood-based models when presented with out-of-domain data. They observe how <strong>models can assign higher likelihoods to datasets other than their training datasets</strong>. Crucially, they show this for different classes of likelihood-based models (variational autoencoders, autoregressive models and flow-based models, see Figure 3 in the paper), which clearly demonstrates that this is an issue with the likelihood-based paradigm itself, and not with a particular model architecture or formulation.</p>
<p>Comparing images from CIFAR-10 and SVHN, two of the datasets they use, a key difference is the prevalence of textures in CIFAR-10 images, and the relative absence of such textures in SVHN images. This makes SVHN images inherently easier to predict, which partially explains why models trained on CIFAR-10 tend to assign higher likelihoods to SVHN images. Despite this, we clearly wouldn’t ever be able to sample anything that looks like an SVHN image from a CIFAR-10-trained model, because such images are not in the typical set of the model distribution (even if their likelihood is higher).</p>
<h3 id="audio-modelling">Audio modelling</h3>
<p>I don’t believe I’ve seen any recent work that studies sampling and decoding strategies for likelihood-based models in the audio domain. Nevertheless, I wanted to briefly discuss this setting because a question I often get is: <em>“why don’t you use greedy decoding or beam search to improve the quality of WaveNet samples?”</em></p>
<p>If you’ve read this far, the answer is probably clear to you by now: because <strong>audio samples outside of the typical set sound really weird</strong>! In fact, greedy decoding from a WaveNet will invariably yield complete silence, even for fairly strongly conditioned models (e.g. WaveNets for text-to-speech synthesis). In the text-to-speech case, even if you simply reduce the sampling temperature a bit too aggressively, certain consonants that are inherently noisy (such as ‘s’, ‘f’, ‘sh’ and ‘h’, the <a href="https://en.wikipedia.org/wiki/Fricative_consonant"><em>fricatives</em></a>) will start sounding very muffled. These sounds are effectively different kinds of noise, and reducing the stochasticity of this noise has an audible effect.</p>
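<p>The sampling temperature mentioned above is worth spelling out: the model's output logits are divided by a temperature T before the softmax, so temperatures below 1 concentrate probability mass on the likeliest outcomes. A minimal sketch, with made-up logits:</p>

```python
# Minimal sketch of temperature scaling: divide the logits by T before the
# softmax. Temperatures below 1 sharpen the distribution, suppressing
# exactly the kind of stochasticity that noise-like sounds depend on.
import numpy as np

def softmax_with_temperature(logits, T):
    scaled = logits / T
    scaled = scaled - scaled.max()           # subtract the max for numerical stability
    exps = np.exp(scaled)
    return exps / exps.sum()

logits = np.array([2.0, 1.0, 0.5, 0.1])      # made-up logits for 4 outcomes
for T in (1.0, 0.5, 0.1):
    print(T, softmax_with_temperature(logits, T).round(3))
```

<p>At very low temperatures almost all of the mass ends up on a single outcome, which is what makes aggressively temperature-reduced fricatives sound muffled.</p>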
<h3 id="anomaly-detection">Anomaly detection</h3>
<p>Anomaly detection, or out-of-distribution (OOD) detection, is the task of identifying whether a particular input could have been drawn from a given distribution. Generative models are often used for this purpose: train an explicit model on in-distribution data, and then use its likelihood estimates to identify OOD inputs.</p>
<p>Usually, the assumption is made that OOD inputs will have low likelihoods, and in-distribution inputs will have high likelihoods. However, the fact that the mode of a high-dimensional distribution usually isn’t part of its typical set clearly contradicts this. This mistaken assumption is quite pervasive. Only recently has it started to be challenged explicitly, e.g. in works by <a href="https://arxiv.org/abs/1906.02994">Nalisnick et al.</a><sup id="fnref:oodtypicality" role="doc-noteref"><a href="#fn:oodtypicality" class="footnote" rel="footnote">8</a></sup> and <a href="https://arxiv.org/abs/2006.09273">Morningstar et al.</a><sup id="fnref:dose" role="doc-noteref"><a href="#fn:dose" class="footnote" rel="footnote">9</a></sup>. Both of these works propose <strong>testing the typicality of inputs, rather than simply measuring and thresholding their likelihood</strong>.</p>
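<p>For a standard Gaussian, where everything is analytically tractable, such a typicality test is easy to sketch: flag an input as typical if its negative log-likelihood is close to the distribution's entropy. The tolerance below is an arbitrary choice of mine, and for a learned density model the entropy would have to be estimated from samples, but the contrast with likelihood thresholding is already visible: the mode has the highest possible likelihood, yet fails the test.</p>

```python
# Sketch of a typicality test for a K-dimensional standard Gaussian: accept
# an input if its negative log-likelihood (NLL) is within eps of the
# differential entropy. A simple likelihood threshold would happily accept
# the all-zeros mode (it has the lowest NLL of all); this test rejects it.
import numpy as np

K = 100
entropy = 0.5 * K * np.log(2 * np.pi * np.e)   # differential entropy

def neg_log_likelihood(x):
    return 0.5 * K * np.log(2 * np.pi) + 0.5 * np.sum(x ** 2)

def is_typical(x, eps=25.0):                   # eps picked by hand for K = 100
    return abs(neg_log_likelihood(x) - entropy) < eps

rng = np.random.default_rng(0)
samples = rng.standard_normal((1000, K))
print(np.mean([is_typical(x) for x in samples]))   # close to 1
print(is_typical(np.zeros(K)))                     # False: the mode is atypical
```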
<h2 id="-the-right-level-of-abstraction"><a name="right-level"></a> The right level of abstraction</h2>
<figure>
<img src="/images/levels.jpg" />
</figure>
<p>While our intuitive notion of likelihood in high-dimensional spaces might technically be wrong, it can often be a better representation of what we actually care about. This raises the question: <strong>should we really be fitting our generative models using likelihood measured in the input space?</strong> If we were to train likelihood-based models with ‘intuitive’ likelihood, they might perform better according to perceptual metrics, because they do not have to waste capacity capturing all the idiosyncrasies of particular examples that we don’t care to distinguish anyway.</p>
<p>In fact, measuring likelihood in more abstract representation spaces has had some success in generative modelling, and I think the approach should be taken more seriously in general. In language modelling, it is common to measure likelihoods at the level of word pieces, rather than individual characters. In symbolic music modelling, recent models that operate on event-based sequences (rather than sequences with a fixed time quantum) are more effective at capturing large-scale structure<sup id="fnref:perfrnn" role="doc-noteref"><a href="#fn:perfrnn" class="footnote" rel="footnote">10</a></sup>. Some likelihood-based generative models of images separate or discard the least-significant bits of each pixel colour value, because they are less perceptually relevant, allowing model capacity to be used more efficiently<sup id="fnref:spn" role="doc-noteref"><a href="#fn:spn" class="footnote" rel="footnote">11</a></sup> <sup id="fnref:glow" role="doc-noteref"><a href="#fn:glow" class="footnote" rel="footnote">12</a></sup>.</p>
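<p>The bit-separation trick for pixels is simple enough to show directly. The 5/3 split below is just an illustrative choice (the cited models use various splits), but the mechanics are the same: peel off the least significant bits, which carry the least perceptually relevant detail.</p>

```python
# Sketch of bit separation for 8-bit pixel values: split each value into
# its 5 most significant bits (the perceptually important part) and its 3
# least significant bits (near-imperceptible detail, modelled cheaply or
# simply discarded). The split is lossless: the two parts recombine exactly.
import numpy as np

pixels = np.array([203, 17, 128, 255], dtype=np.uint8)  # made-up pixel values
msb = pixels >> 3          # top 5 bits: values in [0, 31]
lsb = pixels & 0b111       # bottom 3 bits: values in [0, 7]
reconstructed = (msb << 3) | lsb
assert np.array_equal(reconstructed, pixels)
```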
<p>But perhaps the most striking example is the recent line of work where VQ-VAE<sup id="fnref:vqvae" role="doc-noteref"><a href="#fn:vqvae" class="footnote" rel="footnote">13</a></sup> is used to <strong>learn discrete higher-level representations</strong> of perceptual signals, and generative models are then trained to maximise the likelihood in this representation space. This approach has led to models that produce images that are on par with those produced by GANs in terms of fidelity, and exceed them in terms of diversity<sup id="fnref:vqvae2" role="doc-noteref"><a href="#fn:vqvae2" class="footnote" rel="footnote">14</a></sup> <sup id="fnref:ham" role="doc-noteref"><a href="#fn:ham" class="footnote" rel="footnote">15</a></sup> <sup id="fnref:cas" role="doc-noteref"><a href="#fn:cas" class="footnote" rel="footnote">16</a></sup>. It has also led to models that are able to capture long-range temporal structure in audio signals, which even GANs had not been able to do before<sup id="fnref:challenge" role="doc-noteref"><a href="#fn:challenge" class="footnote" rel="footnote">17</a></sup> <sup id="fnref:jukebox" role="doc-noteref"><a href="#fn:jukebox" class="footnote" rel="footnote">18</a></sup>. While the current trend in representation learning is to focus on coarse-grained representations which are suitable for discriminative downstream tasks, I think it also has a very important role to play in generative modelling.</p>
<p>In the context of modelling sets with likelihood-based models, <a href="http://akosiorek.github.io/ml/2020/08/12/machine_learning_of_sets.html#what-about-those-point-processes">a recent blog post by Adam Kosiorek</a> drew my attention to point processes, and in particular, to the formula that expresses the density over ordered sequences in terms of the density over unordered sets. This formula quantifies how we need to scale probabilities across sets of different sizes to make them comparable. I think it may yet prove useful to quantify the unintuitive behaviours of likelihood-based models.</p>
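<p>The simplest instance of that scaling is worth writing down: for n distinct elements, the probability of the unordered set is n! times that of any particular ordering, because each of the n! orderings corresponds to the same set. A quick numerical check with i.i.d. die rolls (my own toy example, not taken from the blog post in question):</p>

```python
# Toy check of the sequence-vs-set scaling: with three i.i.d. die rolls,
# the unordered outcome {1, 2, 3} is 3! = 6 times as probable as the
# ordered outcome (1, 2, 3), since any of the 6 orderings yields the set.
from itertools import permutations
from math import factorial, isclose

n = 3
p_sequence = (1 / 6) ** n                        # P of rolling (1, 2, 3) in that order
p_set = sum((1 / 6) ** n for _ in permutations(range(1, n + 1)))  # any order
assert isclose(p_set, factorial(n) * p_sequence)
```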
<h2 id="-closing-thoughts"><a name="closing-thoughts"></a> Closing thoughts</h2>
<figure>
<img src="/images/closing_thoughts.jpg" />
</figure>
<p>To wrap up this post, here are some takeaways:</p>
<ul>
<li>
<p><strong>High-dimensional spaces</strong>, and high-dimensional probability distributions in particular, are <strong>deeply unintuitive</strong> in more ways than one. This is a well-known fact, but they still manage to surprise us sometimes!</p>
</li>
<li>
<p>The <strong>most likely samples</strong> from a high-dimensional distribution usually aren’t a very good representation of that distribution. In most situations, we probably shouldn’t be trying to find them.</p>
</li>
<li>
<p><strong>Typicality</strong> is a very useful concept to describe these unintuitive phenomena, and I think it is <strong>undervalued in machine learning</strong> — at least in the work that I’ve been exposed to.</p>
</li>
<li>
<p>A lot of work that discusses these issues (including some that I’ve highlighted in this post) <strong>doesn’t actually refer to typicality by name</strong>. I think doing so would improve our collective understanding, and shed light on links between related phenomena in different subfields.</p>
</li>
</ul>
<p>If you have any thoughts about this topic, please don’t hesitate to share them in the comments below!</p>
<p style="background-color: #eee; padding: 1em; font-size: 120%; text-align: center; border: 1px solid #ccc; border-radius: 0.5em;">
In <a href="/2020/09/01/typicality-addendum.html">an addendum to this post</a>, I explore quantitatively what happens when our intuitions fail us in high-dimensional spaces.
</p>
<p><em>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</em></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@misc{dieleman2020typicality,
author = {Dieleman, Sander},
title = {Musings on typicality},
url = {https://benanne.github.io/2020/09/01/typicality.html},
year = {2020}
}
</code></pre></div></div>
<h2 id="-acknowledgements"><a name="Acknowledgements"></a> Acknowledgements</h2>
<p>Thanks to Katie Millican, Jeffrey De Fauw and Adam Kosiorek for their valuable input and feedback on this post!</p>
<h2 id="-references"><a name="references"></a> References</h2>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:anote" role="doc-endnote">
<p>Theis, van den Oord and Bethge, “<a href="https://arxiv.org/abs/1511.01844">A note on the evaluation of generative models</a>”, International Conference on Learning Representations, 2016. <a href="#fnref:anote" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:anote:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:sixchallenges" role="doc-endnote">
<p>Koehn & Knowles, “<a href="https://arxiv.org/abs/1706.03872">Six Challenges for Neural Machine Translation</a>”, First Workshop on Neural Machine Translation, 2017. <a href="#fnref:sixchallenges" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:analyzinguncertainty" role="doc-endnote">
<p>Ott, Auli, Grangier and Ranzato, “<a href="https://arxiv.org/abs/1803.00047">Analyzing Uncertainty in Neural Machine Translation</a>”, International Conference on Machine Learning, 2018. <a href="#fnref:analyzinguncertainty" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:tradeoff" role="doc-endnote">
<p>Zhang, Duckworth, Ippolito and Neelakantan, “<a href="https://arxiv.org/abs/2004.10450">Trading Off Diversity and Quality in Natural Language Generation</a>”, arXiv, 2020. <a href="#fnref:tradeoff" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:tradeoff:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:mapdecoding" role="doc-endnote">
<p>Eikema and Aziz, “<a href="https://arxiv.org/abs/2005.10283">Is MAP Decoding All You Need? The Inadequacy of the Mode in Neural Machine Translation</a>”, arXiv, 2020. <a href="#fnref:mapdecoding" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:degeneration" role="doc-endnote">
<p>Holtzman, Buys, Du, Forbes and Choi, “<a href="https://arxiv.org/abs/1904.09751">The Curious Case of Neural Text Degeneration</a>”, International Conference on Learning Representations, 2020. <a href="#fnref:degeneration" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:know" role="doc-endnote">
<p>Nalisnick, Matsukawa, Teh, Gorur and Lakshminarayanan, “<a href="https://arxiv.org/abs/1810.09136">Do Deep Generative Models Know What They Don’t Know?</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:know" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:oodtypicality" role="doc-endnote">
<p>Nalisnick, Matsukawa, Teh and Lakshminarayanan, “<a href="https://arxiv.org/abs/1906.02994">Detecting Out-of-Distribution Inputs to Deep Generative Models Using Typicality</a>”, arXiv, 2019. <a href="#fnref:oodtypicality" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:dose" role="doc-endnote">
<p>Morningstar, Ham, Gallagher, Lakshminarayanan, Alemi and Dillon, “<a href="https://arxiv.org/abs/2006.09273">Density of States Estimation for Out-of-Distribution Detection</a>”, arXiv, 2020. <a href="#fnref:dose" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:perfrnn" role="doc-endnote">
<p>Oore, Simon, Dieleman, Eck and Simonyan, “<a href="https://arxiv.org/abs/1808.03715">This Time with Feeling: Learning Expressive Musical Performance</a>”, Neural Computing and Applications, 2020. <a href="#fnref:perfrnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:spn" role="doc-endnote">
<p>Menick and Kalchbrenner, “<a href="https://arxiv.org/abs/1812.01608">Generating High Fidelity Images with Subscale Pixel Networks and Multidimensional Upscaling</a>”, International Conference on Machine Learning, 2019. <a href="#fnref:spn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:glow" role="doc-endnote">
<p>Kingma & Dhariwal, “<a href="https://arxiv.org/abs/1807.03039">Glow: Generative flow with invertible 1x1 convolutions</a>”, Neural Information Processing Systems, 2018. <a href="#fnref:glow" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqvae" role="doc-endnote">
<p>van den Oord, Vinyals and Kavukcuoglu, “<a href="https://arxiv.org/abs/1711.00937">Neural Discrete Representation Learning</a>”, Neural Information Processing Systems, 2017. <a href="#fnref:vqvae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqvae2" role="doc-endnote">
<p>Razavi, van den Oord and Vinyals, “<a href="https://arxiv.org/abs/1906.00446">Generating Diverse High-Fidelity Images with VQ-VAE-2</a>”, Neural Information Processing Systems, 2019. <a href="#fnref:vqvae2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ham" role="doc-endnote">
<p>De Fauw, Dieleman and Simonyan, “<a href="https://arxiv.org/abs/1903.04933">Hierarchical Autoregressive Image Models with Auxiliary Decoders</a>”, arXiv, 2019. <a href="#fnref:ham" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:cas" role="doc-endnote">
<p>Ravuri and Vinyals, “<a href="https://arxiv.org/abs/1905.10887">Classification Accuracy Score for Conditional Generative Models</a>”, Neural Information Processing Systems, 2019. <a href="#fnref:cas" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:challenge" role="doc-endnote">
<p>Dieleman, van den Oord and Simonyan, “<a href="https://arxiv.org/abs/1806.10474">The challenge of realistic music generation: modelling raw audio at scale</a>”, Neural Information Processing Systems, 2018. <a href="#fnref:challenge" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:jukebox" role="doc-endnote">
<p>Dhariwal, Jun, Payne, Kim, Radford and Sutskever, “<a href="https://arxiv.org/abs/2005.00341">Jukebox: A Generative Model for Music</a>”, arXiv, 2020. <a href="#fnref:jukebox" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>If you’re training or sampling from generative models, typicality is a concept worth understanding. It sheds light on why beam search doesn’t work for autoregressive models of images, audio and video; why you can’t just threshold the likelihood to perform anomaly detection with generative models; and why high-dimensional Gaussians are “soap bubbles”. This post is a summary of my current thoughts on the topic.Generating music in the waveform domain2020-03-24T00:00:00+00:002020-03-24T00:00:00+00:00https://sander.ai/2020/03/24/audio-generation<p>In November last year, I co-presented a tutorial on <strong>waveform-based music processing with deep learning</strong> with <a href="http://www.jordipons.me/">Jordi Pons</a> and <a href="https://jongpillee.github.io/">Jongpil Lee</a> at <a href="https://ismir2019.ewi.tudelft.nl/">ISMIR 2019</a>. Jongpil and Jordi talked about music classification and source separation respectively, and I presented the last part of the tutorial, on music generation in the waveform domain. It was very well received, so I’ve decided to write it up in the form of a blog post.</p>
<div style="float: right; width: 30%;"><a href="https://ismir2019.ewi.tudelft.nl/"><img src="/images/ismir_logo.jpg" alt="ISMIR" /></a></div>
<p>ISMIR used to be my home conference when I was a PhD student working on music information retrieval, so it was great to be back for the first time in five years. With about 450 attendees (the largest edition yet), it made for a very different experience than what I’m used to with machine learning conferences like ICML, NeurIPS and ICLR, whose audiences tend to number in the thousands these days.</p>
<p>Our tutorial on the first day of the conference gave rise to plenty of interesting questions and discussions throughout, which inspired me to write some of these things down and hopefully provide a basis to continue these discussions online. Note that I will only be covering music generation in this post, but Jordi and Jongpil are working on blog posts about their respective parts. I will share them here when they are published. In the meantime, <strong>the slide deck we used includes all three parts and is now available on <a href="https://zenodo.org/record/3529714#.XdBi0dv7Sf5">Zenodo (PDF)</a> and on <a href="https://docs.google.com/presentation/d/1_ezZXDkyhp9USAYMc5oKJCkUrUhBfo-Di8H8IfypGBM/edit#slide=id.g647f5a8648_0_57">Google slides</a></strong>. I’ve also added a few things to this post that I’ve thought of since giving the tutorial, and some new work that has come out since.</p>
<p>This is also an excellent opportunity to revive my blog, which has lain dormant for the past four years. I have taken the time to update the blog software, so if anything looks odd, that may be why. Please let me know so I can fix it!</p>
<figure>
<a href="/images/ismir_2019_photo.jpeg"><img src="/images/ismir_2019_photo.jpeg" alt="Presenting our tutorial session at ISMIR 2019 in Delft, The Netherlands." /></a>
<figcaption>Presenting our tutorial session at ISMIR 2019 in Delft, The Netherlands. Via <a href="https://twitter.com/ismir2019/status/1191341227825934336">ISMIR2019 on Twitter</a>.</figcaption>
</figure>
<h2 id="-overview"><a name="overview"></a> Overview</h2>
<p>This blog post is divided into a few different sections. I’ll try to motivate why modelling music in the waveform domain is an interesting problem. Then I’ll give an overview of generative models, the various flavours that exist, and some important ways in which they differ from each other. In the next two sections I’ll attempt to cover the state of the art in both likelihood-based and adversarial models of raw music audio. Finally, I’ll raise some observations and discussion points. If you want to skip ahead, just click the section title below to go there.</p>
<ul>
<li><em><a href="#motivation">Motivation</a></em></li>
<li><em><a href="#generative-models">Generative models</a></em></li>
<li><em><a href="#likelihood-based-models">Likelihood-based models of waveforms</a></em></li>
<li><em><a href="#adversarial-models">Adversarial models of waveforms</a></em></li>
<li><em><a href="#discussion">Discussion</a></em></li>
<li><em><a href="#conclusion">Conclusion</a></em></li>
<li><em><a href="#references">References</a></em></li>
</ul>
<p>Note that this blog post is not intended to provide an exhaustive overview of all the published research in this domain – I have tried to make a selection and I’ve inevitably left out some great work. <strong>Please don’t hesitate to suggest relevant work in the comments section!</strong></p>
<h2 id="-motivation"><a name="motivation"></a> Motivation</h2>
<h3 id="why-audio">Why audio?</h3>
<p>Music generation has traditionally been studied in the <strong>symbolic domain</strong>: the output of the generative process could be a musical score, a sequence of <a href="https://en.wikipedia.org/wiki/MIDI">MIDI events</a>, a simple melody, a sequence of chords, a textual representation<sup id="fnref:folkrnn" role="doc-noteref"><a href="#fn:folkrnn" class="footnote" rel="footnote">1</a></sup> or some other higher-level representation. The physical process through which sound is produced is abstracted away. This dramatically reduces the amount of information that the models are required to produce, which makes the modelling problem more tractable and allows for lower-capacity models to be used effectively.</p>
<p>A very popular representation is the so-called <em>piano roll</em>, which dates back to the player pianos of the early 20th century. Holes were punched into a roll of paper to indicate which notes should be played at which time. This representation survives in digital form today and is commonly used in music production. Much of the work on music generation using machine learning has made use of (some variant of) this representation, because it allows for capturing performance-specific aspects of the music without having to model the sound.</p>
<figure class="half">
<a href="/images/player_piano.jpg"><img src="/images/player_piano.jpg" alt="Player piano with a physical piano roll inside." /></a>
<a href="/images/piano_roll.jpg"><img src="/images/piano_roll.jpg" alt="Modern incarnation of a piano roll." /></a>
<figcaption><strong>Left:</strong> player piano with a physical piano roll inside. <strong>Right:</strong> modern incarnation of a piano roll.</figcaption>
</figure>
<p>Piano rolls are great for piano performances, because they are able to exactly capture the <em>timing</em>, <em>pitch</em> and <em>velocity</em> (i.e. how hard a piano key is pressed, which is correlated with loudness, but not equivalent to it) of the notes. They are able to very accurately represent piano music, because they cover all the “degrees of freedom” that a performer has at their disposal. However, most other instruments have many more degrees of freedom: think about all the various ways you can play a note on the guitar, for example. You can decide which string to use, where to pick, whether to bend the string or not, play vibrato, … you could even play harmonics, or use two-hand tapping. Such a vast array of different playing techniques endows the performer with a lot more freedom to vary the sound that the instrument produces, and coming up with a high-level representation that can accurately capture all this variety is much more challenging. In practice, a lot of this detail is ignored and a simpler representation is often used when generating music for these instruments.</p>
<p>Modelling the sound that an instrument produces is much more difficult than modelling (some of) the parameters that are controlled by the performer, but it frees us from having to manually design high-level representations that accurately capture all these parameters. Furthermore, it allows our models to capture variability that is beyond the performer’s control: the idiosyncrasies of individual instruments, for example (no two violins sound exactly the same!), or the parameters of the recording setup used to obtain the training data for our models. It also makes it possible to model ensembles of instruments, or other sound sources altogether, without having to fundamentally change anything about the model apart from the data it is trained on.</p>
<p>However, digital audio representations require a reasonably high bit rate to achieve acceptable fidelity, and modelling all these bits comes with a cost. <strong>Music audio models will necessarily require much higher capacity than their symbolic counterparts</strong>, which implies higher computational requirements for model training.</p>
<h3 id="why-waveforms"><a name="why-waveforms"></a>Why waveforms?</h3>
<p>Digital representations of sound come in many shapes and forms. For reproduction, sound is usually stored by encoding the shape of the waveform as it changes over time. For analysis however, we often make use of <strong><a href="https://en.wikipedia.org/wiki/Spectrogram">spectrograms</a></strong>, both for computational methods and for visual inspection by humans. A spectrogram can be obtained from a waveform by computing the Fourier transform of overlapping windows of the signal, and stacking the results into a 2D array. This shows the <strong>local frequency content of the signal over time</strong>.</p>
<p>Spectrograms are complex-valued: they represent both the amplitude and the phase of different frequency components at each point in time. Below is a visualisation of a magnitude spectrogram and its corresponding phase spectrogram. While the magnitude spectrogram clearly exhibits a lot of structure, with sustained frequencies manifesting as horizontal lines and harmonics showing up as parallel horizontal lines, the phase spectrogram looks a lot more random.</p>
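To make this concrete, here is a minimal numpy sketch of the procedure just described. The frame and hop sizes are illustrative choices, and the 440 Hz test tone is a stand-in for a real recording:

```python
import numpy as np

# stand-in signal: a 440 Hz sine at a 16 kHz sample rate
sr = 16000
t = np.arange(sr) / sr
x = np.sin(2 * np.pi * 440 * t)

# Fourier transforms of overlapping windowed frames, stacked into a 2D array
frame, hop = 512, 256
win = np.hanning(frame)
frames = np.stack([x[i:i + frame] * win
                   for i in range(0, len(x) - frame, hop)])
S = np.fft.rfft(frames, axis=1)  # complex-valued: (num_frames, frame // 2 + 1)

magnitude = np.abs(S)   # what is usually plotted as "the spectrogram"
phase = np.angle(S)     # wraps around in [-pi, pi)
```

The sustained tone shows up as a single horizontal line in the magnitude spectrogram: the peak sits in bin round(440 · 512 / 16000) = 14 in every frame.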
<figure>
<a href="/images/spectrogram_magnitude.png"><img src="/images/spectrogram_magnitude.png" alt="Magnitude spectrogram of a piano recording." /></a>
<a href="/images/spectrogram_phase.png"><img src="/images/spectrogram_phase.png" alt="Phase spectrogram of a piano recording." /></a>
<figcaption><strong>Top:</strong> magnitude spectrogram of a piano recording. <strong>Bottom:</strong> the corresponding phase spectrogram.</figcaption>
</figure>
<p>When extracting information from audio signals, it turns out that we can often just <strong>discard the phase component</strong>, because it is not informative for most of the things we could be interested in. In fact, this is why the magnitude spectrogram is often referred to simply as “the spectrogram”. When generating sound however, phase is very important because it meaningfully affects our perception. Listen below to an original excerpt of a piano piece, and a corresponding excerpt where the original phase has been replaced by random uniform phase information. Note how the harmony is preserved, but the timbre changes completely.</p>
<figure class="half">
<audio controls="" src="/files/original_phase.wav"><a href="/files/original_phase.wav">Audio with original phase</a></audio>
<audio controls="" src="/files/random_phase.wav"><a href="/files/random_phase.wav">Audio with random phase</a></audio>
<figcaption><strong>Left:</strong> excerpt with original phase. <strong>Right:</strong> the same excerpt with random phase.</figcaption>
</figure>
<p>The phase component of a spectrogram is tricky to model for a number of reasons:</p>
<ul>
<li>it is an <strong>angle</strong>: \(\phi \in [0, 2 \pi)\) and it wraps around;</li>
<li>it becomes <strong>effectively random</strong> as the magnitude tends towards 0, because noise starts to dominate;</li>
<li>absolute phase is less meaningful, but <strong>relative phase differences over time matter perceptually</strong>.</li>
</ul>
<p>If we model waveforms directly, we are implicitly modelling their phase as well, but we don’t run into these issues that make modelling phase so cumbersome. There are other strategies to avoid these issues, some of which I will <a href="#alternatives">discuss later</a>, but <strong>waveform modelling currently seems to be the dominant approach in the generative setting</strong>. This is particularly interesting because magnitude spectrograms are by far the most common representation used for discriminative models of audio.</p>
<h3 id="discretising-waveforms">Discretising waveforms</h3>
<p>When representing a waveform digitally, we need to <strong>discretise it in both time and amplitude</strong>. This is referred to as <a href="https://en.wikipedia.org/wiki/Pulse-code_modulation">pulse code modulation (PCM)</a>. Because audio waveforms are effectively band-limited (humans cannot perceive frequencies above ~20 kHz), the <a href="https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem">sampling theorem</a> tells us that we can discretise the waveform in time without any loss of information, as long as the sample rate is high enough (twice the highest frequency). This is why CD quality audio has a sample rate of 44.1 kHz. Much lower sample rates result in an audible loss of fidelity, but since the resulting discrete sequences also end up being much shorter, a compromise is often struck in the context of generative modelling to reduce computational requirements. Most models in the literature use sample rates of 16 or 24 kHz.</p>
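The sampling theorem also tells us what happens when it is violated: any frequency above half the sample rate gets <em>aliased</em> onto a lower one. A quick numerical check, at a hypothetical 16 kHz sample rate:

```python
import numpy as np

sr = 16000                # sample rate; the Nyquist frequency is sr / 2 = 8 kHz
t = np.arange(sr) / sr    # one second of sampling times

# a 15 kHz tone is above Nyquist, so sampling folds it back down:
# sin(2*pi*15000*t) = sin(2*pi*(16000 - 1000)*t) = -sin(2*pi*1000*t) at t = k/sr
x_hi = np.sin(2 * np.pi * 15000 * t)
x_lo = np.sin(2 * np.pi * 1000 * t)
print(np.allclose(x_hi, -x_lo))  # True: the two tones are indistinguishable once sampled
```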
<figure>
<a href="/images/digital_waveform.gif"><img style="width: 100%; border: 1px solid #eee;" src="/images/digital_waveform.gif" alt="Digital waveform." /></a>
<figcaption>Digital waveform. The individual samples become visible as the zoom level increases. Figure taken from <a href="https://deepmind.com/blog/article/wavenet-generative-model-raw-audio">the original WaveNet blog post</a>.</figcaption>
</figure>
<p>When we also quantise the amplitude, some loss of fidelity is inevitable. CD quality uses 16 bits per sample, representing 2<sup>16</sup> equally spaced quantisation levels. If we want to use fewer bits, we can use logarithmically spaced quantisation levels instead to account for our nonlinear perception of loudness. This <strong><a href="https://en.wikipedia.org/wiki/%CE%9C-law_algorithm">“mu-law companding”</a></strong> will result in a smaller perceived loss of fidelity than if the levels were equally spaced.</p>
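A sketch of this companding scheme, following the standard μ-law definition with μ = 255 (i.e. 256 levels, one byte per sample):

```python
import numpy as np

def mu_law_encode(x, mu=255):
    """Map x in [-1, 1] to integer codes {0, ..., mu} with log-spaced levels."""
    y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)  # compress to [-1, 1]
    return np.round((y + 1) / 2 * mu).astype(np.int64)

def mu_law_decode(codes, mu=255):
    """Invert the companding: integer codes back to values in [-1, 1]."""
    y = 2 * codes.astype(np.float64) / mu - 1
    return np.sign(y) * np.expm1(np.abs(y) * np.log1p(mu)) / mu

x = np.linspace(-1, 1, 1001)
x_hat = mu_law_decode(mu_law_encode(x))
# the quantisation error is much smaller near zero than with 256 linear levels,
# matching our nonlinear perception of loudness
```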
<h2 id="-generative-models"><a name="generative-models"></a> Generative models</h2>
<p>Given a dataset \(X\) of examples \(x \in X\), which we assume to have been drawn independently from some underlying distribution \(p_X(x)\), a generative model can learn to approximate this distribution \(p_X(x)\). Such a model could be used to generate new samples that look like they could have been part of the original dataset. We distinguish <em>implicit</em> and <em>explicit</em> generative models: an implicit model can produce new samples \(x \sim p_X(x)\), but cannot be used to infer the likelihood of an example (i.e. we cannot tractably compute \(p_X(x)\) given \(x\)). If we have an explicit model, we can do this, though sometimes only up to an unknown normalising constant.</p>
<h3 id="conditional-generative-models">Conditional generative models</h3>
<p>Generative models become more practically useful when we can exert some influence over the samples we draw from them. We can do this by providing a <strong>conditioning signal</strong> \(c\), which contains side information about the kind of samples we want to generate. The model is then fit to the conditional distribution \(p_X(x \vert c)\) instead of \(p_X(x)\).</p>
<p>Conditioning signals can take many shapes or forms, and it is useful to distinguish different levels of information content. The generative modelling problem becomes easier if the conditioning signal \(c\) is richer, because it reduces uncertainty about \(x\). We will refer to conditioning signals with low information content as <em>sparse conditioning</em>, and those with high information content as <em>dense conditioning</em>. Examples of conditioning signals in the image domain and the music audio domain are shown below, ordered according to density.</p>
<figure>
<img src="/images/sparse-dense-conditioning.svg" alt="Examples of sparse and dense conditioning signals in the image domain (top) and the music audio domain (bottom)." />
<figcaption>Examples of sparse and dense conditioning signals in the image domain (top) and the music audio domain (bottom).</figcaption>
</figure>
<p>Note that the density of a conditioning signal is often correlated with its level of abstraction: high-level side information tends to be more sparse. Low-level side information isn’t necessarily dense, though. For example, we could condition a generative model of music audio on a low-dimensional vector that captures the overall timbre of an instrument. This is a low-level aspect of the audio signal, but it constitutes a sparse conditioning signal.</p>
<h3 id="likelihood-based-models">Likelihood-based models</h3>
<p>Likelihood-based models directly parameterise \(p_X(x)\). The parameters \(\theta\) are then fit by maximising the likelihood of the data under the model:</p>
\[\mathcal{L}(\theta) = \sum_{x \in X} \log p_X(x|\theta) \quad \quad \theta^* = \arg \max_\theta \mathcal{L}(\theta) .\]
<p>Note that this is typically done in the log-domain because it simplifies computations and improves numerical stability. Because the model directly parameterises \(p_X(x)\), we can <strong>easily infer the likelihood of any</strong> \(x\), so we get an explicit model. Three popular flavours of likelihood-based models are autoregressive models, flow-based models and variational autoencoders. The following three subsections provide a brief overview of each.</p>
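As a minimal worked example of this recipe, consider a model family for which the maximisation has a closed-form solution: a univariate Gaussian.

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=2.0, size=100_000)

def log_likelihood(data, mu, sigma):
    # sum of log p(x | theta) over the dataset, for a Gaussian model
    return np.sum(-0.5 * ((data - mu) / sigma) ** 2
                  - np.log(sigma) - 0.5 * np.log(2 * np.pi))

# for this family the maximum likelihood estimate is available in closed form:
# the sample mean and standard deviation
mu_star, sigma_star = data.mean(), data.std()

# perturbing the optimum in either direction decreases the log-likelihood
assert log_likelihood(data, mu_star, sigma_star) > log_likelihood(data, mu_star + 0.1, sigma_star)
assert log_likelihood(data, mu_star, sigma_star) > log_likelihood(data, mu_star, 1.1 * sigma_star)
```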
<h3 id="autoregressive-models">Autoregressive models</h3>
<p>In an autoregressive model, we assume that our examples \(x \in X\) can be treated as sequences \(\{x_i\}\). We then factorise the distribution into a product of conditionals, using the <a href="https://en.wikipedia.org/wiki/Chain_rule_(probability)">chain rule of probability</a>:</p>
\[p_X(x) = \prod_i p(x_i \vert x_{<i}) .\]
<p>These conditional distributions are typically scalar-valued and much easier to model. Because we further assume that the distribution of the sequence elements is stationary, we can share parameters and use the same model for all the factors in this product.</p>
<p>For audio signals, this is a very natural thing to do, but we can also do this for other types of structured data by arbitrarily choosing an order (e.g. raster scan order for images, as in PixelRNN<sup id="fnref:pixelrnn" role="doc-noteref"><a href="#fn:pixelrnn" class="footnote" rel="footnote">2</a></sup> and PixelCNN<sup id="fnref:pixelcnn" role="doc-noteref"><a href="#fn:pixelcnn" class="footnote" rel="footnote">3</a></sup>).</p>
<p>Autoregressive models are attractive because they are able to <strong>accurately capture correlations between the different elements</strong> \(x_i\) in a sequence, and they allow for fast inference (i.e. computing \(p_X(x)\) given \(x\)). Unfortunately they tend to be <strong>slow to sample from</strong>, because samples need to be drawn sequentially from the conditionals for each position in the sequence.</p>
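The factorisation, and the resulting asymmetry between inference and sampling, can be sketched with a toy bigram model standing in for the shared network that parameterises \(p(x_i \vert x_{<i})\). The transition table below is made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# toy conditional p(x_i | x_{i-1}) over 4 quantised signal levels;
# in a real model (e.g. WaveNet) a neural network conditioned on a much
# longer context would produce these distributions
P = np.array([[0.7, 0.2, 0.1, 0.0],
              [0.1, 0.6, 0.2, 0.1],
              [0.0, 0.2, 0.6, 0.2],
              [0.1, 0.1, 0.2, 0.6]])

def sample(length, start=0):
    seq = [start]
    for _ in range(length - 1):           # sequential: one timestep at a time
        seq.append(rng.choice(4, p=P[seq[-1]]))
    return np.array(seq)

def log_likelihood(seq):
    # all factors can be evaluated at once, so inference is fast and parallel
    return np.log(P[seq[:-1], seq[1:]]).sum()

x = sample(1000)
```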
<h3 id="flow-based-models">Flow-based models</h3>
<p>Another strategy for constructing a likelihood-based model is to use the <strong><a href="https://en.wikipedia.org/wiki/Probability_density_function#Function_of_random_variables_and_change_of_variables_in_the_probability_density_function">change of variables theorem</a></strong> to transform \(p_X(x)\) into a simple, factorised distribution \(p_Z(z)\) (standard Gaussian is a popular choice) using an invertible mapping \(x = g(z)\):</p>
\[p_X(x) = p_Z(z) \cdot |\det J|^{-1} \quad \quad J = \frac{dg(z)}{dz}.\]
<p>Here, \(J\) is the Jacobian of \(g(z)\). Models that use this approach are referred to as normalising flows or flow-based models<sup id="fnref:nice" role="doc-noteref"><a href="#fn:nice" class="footnote" rel="footnote">4</a></sup><sup id="fnref:realnvp" role="doc-noteref"><a href="#fn:realnvp" class="footnote" rel="footnote">5</a></sup>. They are fast both for inference and sampling, but the <strong>requirement for \(g(z)\) to be invertible significantly constrains the model architecture</strong>, and it makes them less parameter-efficient. In other words: flow-based models need to be quite large to be effective.</p>
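A one-dimensional sanity check of the change of variables formula, using an affine mapping \(g(z) = az + b\) (chosen for illustration, so the Jacobian is just the scalar \(a\)):

```python
import numpy as np

def log_pz(z):
    # log density of the standard Gaussian base distribution
    return -0.5 * (z ** 2 + np.log(2 * np.pi))

a, b = 2.0, 1.0              # x = g(z) = a * z + b, so J = dg/dz = a
x = 3.0
z = (x - b) / a              # invert the mapping
log_px = log_pz(z) - np.log(abs(a))  # change of variables: |det J|^{-1} = 1 / |a|

# this agrees with the density of N(b, a^2), which is how x is distributed
ref = -0.5 * ((x - b) ** 2 / a ** 2 + np.log(2 * np.pi * a ** 2))
assert np.isclose(log_px, ref)
```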
<p>For an in-depth treatment of flow-based models, I recommend Eric Jang’s <a href="https://blog.evjang.com/2018/01/nf1.html">two-part blog post</a> on the subject, and <a href="https://arxiv.org/abs/1912.02762">Papamakarios et al.’s excellent review paper</a>.</p>
<h3 id="variational-autoencoders-vaes">Variational autoencoders (VAEs)</h3>
<p>By far the most popular class of likelihood-based generative models, I can’t avoid mentioning variational<sup id="fnref:vaerezende" role="doc-noteref"><a href="#fn:vaerezende" class="footnote" rel="footnote">6</a></sup> autoencoders<sup id="fnref:vaekingma" role="doc-noteref"><a href="#fn:vaekingma" class="footnote" rel="footnote">7</a></sup> – but <strong>in the context of waveform modelling, they are probably the least popular approach</strong>. In a VAE, we jointly learn two neural networks: an <em>inference network</em> \(q(z \vert x)\) learns to probabilistically map examples \(x\) into a latent space, and a <em>generative network</em> \(p(x \vert z)\) learns the distribution of the data conditioned on a latent representation \(z\). These are trained to maximise a lower bound on \(p_X(x)\), called the ELBO (Evidence Lower BOund), because computing \(p_X(x)\) given \(x\) (exact inference) is not tractable.</p>
<p>Typical VAEs assume a factorised distribution for \(p(x \vert z)\), which limits the extent to which they can capture dependencies in the data. While this is often an acceptable trade-off, in the case of waveform modelling it turns out to be a problematic restriction in practice. I believe this is why not a lot of work has been published that takes this approach (if you know of any, please point me to it). VAEs can also have more powerful decoders with fewer assumptions (autoregressive decoders, for example), but this may introduce other issues such as posterior collapse<sup id="fnref:pc" role="doc-noteref"><a href="#fn:pc" class="footnote" rel="footnote">8</a></sup>.</p>
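A sketch of the ELBO computation under the factorised Gaussian assumptions discussed above, with trivial stand-ins for the inference and generative networks:

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_kl(mu, log_var):
    # KL(q(z|x) || p(z)) in closed form, for a factorised Gaussian posterior
    # and a standard Gaussian prior
    return 0.5 * np.sum(mu ** 2 + np.exp(log_var) - 1.0 - log_var)

def elbo(x, mu_z, log_var_z, decode):
    # single-sample Monte Carlo estimate of E_q[log p(x|z)] - KL(q(z|x) || p(z))
    eps = rng.standard_normal(mu_z.shape)
    z = mu_z + np.exp(0.5 * log_var_z) * eps       # reparameterisation trick
    x_mean = decode(z)
    # factorised Gaussian likelihood with unit variance: the restrictive
    # assumption that limits how well dependencies in x can be captured
    log_px_z = -0.5 * np.sum((x - x_mean) ** 2 + np.log(2 * np.pi))
    return log_px_z - gaussian_kl(mu_z, log_var_z)

# trivial stand-ins for the inference and generative networks
x = rng.standard_normal(8)
mu_z, log_var_z = np.zeros(2), np.zeros(2)         # q(z|x) = N(0, I) here
decode = lambda z: np.zeros(8)                     # p(x|z) = N(0, I) here
print(elbo(x, mu_z, log_var_z, decode))
```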
<p>To learn more about VAEs, check out <a href="https://jaan.io/what-is-variational-autoencoder-vae-tutorial/">Jaan Altosaar’s tutorial</a>.</p>
<h3 id="adversarial-models">Adversarial models</h3>
<p>Generative Adversarial Networks<sup id="fnref:gans" role="doc-noteref"><a href="#fn:gans" class="footnote" rel="footnote">9</a></sup> (GANs) take a very different approach to capturing the data distribution. Two networks are trained simultaneously: a <em>generator</em> \(G\) attempts to produce examples according to the data distribution \(p_X(x)\), given latent vectors \(z\), while a <em>discriminator</em> \(D\) attempts to tell apart generated examples and real examples. In doing so, the discriminator provides a learning signal for the generator which enables it to better match the data distribution. In the original formulation, the loss function is as follows:</p>
\[\mathcal{L} = \mathbb{E}_x[\log D(x)] + \mathbb{E}_z[\log(1 - D(G(z)))] .\]
<p>The generator is trained to minimise this loss, whereas the discriminator attempts to maximise it. This means the training procedure is a <strong>two-player minimax game</strong>, rather than an optimisation process, as it is for most machine learning models. Balancing this game and keeping training stable has been one of the main challenges for this class of models. Many alternative formulations have been proposed to address this.</p>
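The value function itself is cheap to evaluate given batches of discriminator outputs; at the theoretical equilibrium, where the discriminator outputs 1/2 everywhere, it equals \(-2 \log 2\):

```python
import numpy as np

def gan_value(d_real, d_fake):
    # E_x[log D(x)] + E_z[log(1 - D(G(z)))], estimated from batches of
    # discriminator outputs (all in the open interval (0, 1))
    return np.mean(np.log(d_real)) + np.mean(np.log1p(-d_fake))

# a confident discriminator pushes the value up...
v_confident = gan_value(np.array([0.9, 0.95]), np.array([0.05, 0.1]))
# ...while a maximally confused one (D = 1/2 everywhere) yields -2 log 2
v_equilibrium = gan_value(np.array([0.5]), np.array([0.5]))
```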
<p>While adversarial and likelihood-based models are both ultimately trying to model \(p_X(x)\), they approach this target from very different angles. As a result, <strong>GANs tend to be better at producing realistic examples, but worse at capturing the full diversity of the data distribution</strong>, compared to likelihood-based models.</p>
<h3 id="more-exotic-flavours">More exotic flavours</h3>
<p>Many other strategies to learn models of complicated distributions have been proposed in the literature. While research on waveform generation has chiefly focused on the two dominant paradigms of likelihood-based and adversarial models, some of these alternatives may hold promise in this area as well, so I want to mention a few that I’ve come across.</p>
<ul>
<li>
<p><strong>Energy-based models</strong> measure the “energy” of examples, and are trained by fitting the model parameters so that examples coming from the dataset have low energy, whereas all other configurations of inputs have high energy. This amounts to fitting an unnormalised density. A nice recent example is <a href="https://openai.com/blog/energy-based-models/">the work by Du & Mordatch at OpenAI</a><sup id="fnref:energy" role="doc-noteref"><a href="#fn:energy" class="footnote" rel="footnote">10</a></sup>. Energy-based models have been around for a very long time though, and one could argue that likelihood-based models are a special case.</p>
</li>
<li>
<p><strong>Optimal transport</strong> is another approach to measure the discrepancy between probability distributions, which has served as inspiration for new variants of generative adversarial networks<sup id="fnref:wgan" role="doc-noteref"><a href="#fn:wgan" class="footnote" rel="footnote">11</a></sup> and autoencoders<sup id="fnref:swa" role="doc-noteref"><a href="#fn:swa" class="footnote" rel="footnote">12</a></sup>.</p>
</li>
<li>
<p><strong>Autoregressive implicit quantile networks</strong><sup id="fnref:aiqn" role="doc-noteref"><a href="#fn:aiqn" class="footnote" rel="footnote">13</a></sup> use a network architecture similar to that of likelihood-based autoregressive models, but they are trained using the quantile regression loss, rather than maximum likelihood.</p>
</li>
<li>
<p>Two continuous distributions can be matched by minimising the L2 distance between the gradients of the density functions with respect to their inputs: \(\mathcal{L}(x) = \mathbb{E} [\vert\vert \nabla_x \log p_X(x) - \nabla_y \log p_Y(y) \vert\vert ^2]\). This is called <strong>score matching</strong><sup id="fnref:scorematching" role="doc-noteref"><a href="#fn:scorematching" class="footnote" rel="footnote">14</a></sup> and some recent works have revisited this idea for density estimation<sup id="fnref:ssm" role="doc-noteref"><a href="#fn:ssm" class="footnote" rel="footnote">15</a></sup> and generative modelling<sup id="fnref:scorebased" role="doc-noteref"><a href="#fn:scorebased" class="footnote" rel="footnote">16</a></sup>.</p>
</li>
<li>
<p>Please share any others that I haven’t mentioned in the comments!</p>
</li>
</ul>
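To make the score matching objective from the list above concrete: for Gaussians the score \(\nabla_x \log p(x)\) is available in closed form, so the objective can be estimated directly. A toy example with two unit-variance Gaussians:

```python
import numpy as np

def gaussian_score(x, mu, sigma):
    # score of N(mu, sigma^2): the gradient of log p w.r.t. its input
    return -(x - mu) / sigma ** 2

# Monte Carlo estimate of the objective between two unit-variance Gaussians,
# with the expectation taken over samples from the first one
rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=100_000)
fisher_div = np.mean((gaussian_score(x, 0.0, 1.0) - gaussian_score(x, 1.0, 1.0)) ** 2)
# here the score difference is the constant -1, so the estimate is exactly 1
```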
<h3 id="mode-covering-vs-mode-seeking-behaviour">Mode-covering vs. mode-seeking behaviour</h3>
<p>An important consideration when determining which type of generative model is appropriate for a particular application is the degree to which it is <em>mode-covering</em> or <em>mode-seeking</em>. When a model does not have enough capacity to capture all the variability in the data, different compromises can be made. If all examples should be reasonably likely under the model, it will have to overgeneralise and put probability mass on interpolations of examples that may not be meaningful (mode-covering). If there is no such requirement, the probability mass can be focused on a subset of examples, but then some parts of the distribution will be ignored by the model (mode-seeking).</p>
<figure>
<a href="/images/mode_seeking_covering.png"><img src="/images/mode_seeking_covering.png" alt="Illustration of mode-seeking and mode-covering behaviour in model fitting." /></a>
<figcaption>Illustration of mode-seeking and mode-covering behaviour in model fitting. The blue density represents the data distribution. The green density is our model, which is a single Gaussian. Because the data distribution is multimodal, our model does not have enough capacity to accurately capture it.</figcaption>
</figure>
<p><strong>Likelihood-based models are usually mode-covering</strong>. This is a consequence of the fact that they are fit by maximising the joint likelihood of the data. <strong>Adversarial models on the other hand are typically mode-seeking</strong>. A lot of ongoing research is focused on making it possible to control the trade-off between these two behaviours directly, without necessarily having to switch the class of models that are used.</p>
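This difference can be reproduced numerically: fitting a single Gaussian to a bimodal mixture by minimising the forward KL divergence (as likelihood-based training does) versus the reverse KL yields very different optima. A brute-force grid search, purely for illustration:

```python
import numpy as np

xs = np.linspace(-8, 8, 4001)
dx = xs[1] - xs[0]

def normal(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# bimodal "data distribution": a mixture of two narrow Gaussians
p = 0.5 * normal(xs, -2, 0.5) + 0.5 * normal(xs, 2, 0.5)

def kl(a, b):
    # discretised KL divergence, guarding against numerical underflow
    m = a > 1e-12
    b = np.maximum(b, 1e-300)
    return np.sum(a[m] * np.log(a[m] / b[m])) * dx

# brute-force search over single-Gaussian models q
best_fwd = best_rev = (np.inf, 0.0, 0.0)
for mu in np.linspace(-3, 3, 61):
    for sigma in np.linspace(0.3, 3, 28):
        q = normal(xs, mu, sigma)
        best_fwd = min(best_fwd, (kl(p, q), mu, sigma))  # mode-covering
        best_rev = min(best_rev, (kl(q, p), mu, sigma))  # mode-seeking

# forward KL: a wide Gaussian centred between the modes (covers everything);
# reverse KL: a narrow Gaussian locked onto a single mode (ignores the other)
print(best_fwd, best_rev)
```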
<p>In general, mode-covering behaviour is desirable in sparsely conditioned applications, where we want diversity or we expect a certain degree of “creativity” from the model. Mode-seeking behaviour is more useful in densely-conditioned settings, where most of the variability we care about is captured in the conditioning signal, and we favour realism of the generated output over diversity.</p>
<h2 id="-likelihood-based-models-of-waveforms"><a name="likelihood-based-models"></a> Likelihood-based models of waveforms</h2>
<p>In this section, I’ll try to summarise some of the key results from the past four years obtained with likelihood-based models of waveforms. While this blog post is supposed to be about music, note that many of these developments were initially targeted at generating speech, so inevitably I will also be talking about some work in the text-to-speech (TTS) domain. I recommend reading the associated papers and/or blog posts to find out more about each of these works.</p>
<h3 id="wavenet--samplernn">WaveNet & SampleRNN</h3>
<figure>
<a href="/images/wavenet.gif"><img style="display: block; margin: auto;" src="/images/wavenet.gif" alt="Wavenet sampling procedure." /></a>
<figcaption>Animation showing sampling from a WaveNet model. The model predicts the distribution of potential signal values for each timestep, given past signal values.</figcaption>
</figure>
<p>WaveNet<sup id="fnref:wavenet" role="doc-noteref"><a href="#fn:wavenet" class="footnote" rel="footnote">17</a></sup> and SampleRNN<sup id="fnref:samplernn" role="doc-noteref"><a href="#fn:samplernn" class="footnote" rel="footnote">18</a></sup> are <strong>autoregressive models of raw waveforms</strong>. While WaveNet is a convolutional neural network, SampleRNN uses a stack of recurrent neural networks. Both papers appeared on arXiv in late 2016 with only a few months in between, signalling that autoregressive waveform-based audio modelling was an idea whose time had come. Before then, this idea had not been seriously considered, as modelling long-term correlations in sequences across thousands of timesteps did not seem feasible with the tools that were available at that point. Furthermore, discriminative models of audio all used spectral input representations, with only a few works investigating the use of raw waveforms in this setting (and usually with worse results).</p>
<p>Although these models have their flaws (including slow sampling due to autoregressivity, and a lack of interpretability w.r.t. what actually happens inside the network), I think they constituted an important <em>existence proof</em> that encouraged further research into waveform-based models.</p>
<p>WaveNet’s strategy to deal with long-term correlations is to use <em>dilated convolutions</em>: successive convolutional layers use filters with gaps between their inputs, so that the connectivity pattern across many layers forms a tree structure (see figure above). This enables rapid growth of the receptive field, which means that <strong>a WaveNet with only a few layers can learn dependencies across many timesteps</strong>. Note that the convolutions used in WaveNet are causal (no connectivity from future to past), which forces the model to learn to predict what values the signal could take at each position in time.</p>
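The receptive field of such a stack is easy to compute: each causal layer adds (kernel size − 1) times its dilation to the context length. For a hypothetical configuration with three blocks of ten doubling dilations (the exact configuration varies between implementations):

```python
def receptive_field(dilations, kernel_size=2):
    # each causal layer adds (kernel_size - 1) * dilation timesteps of context
    return 1 + (kernel_size - 1) * sum(dilations)

# a WaveNet-style stack: dilations double within a block of 10 layers,
# and the block is repeated 3 times
dilations = [2 ** i for i in range(10)] * 3
rf = receptive_field(dilations)
print(rf, rf / 16000)  # 3070 samples, ~0.19 s of context at 16 kHz
```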
<p>SampleRNN’s strategy is a bit different: multiple RNNs are stacked on top of each other, with each running at a different frequency. Higher-level RNNs update less frequently, which means they can more easily capture long-range correlations and learn high-level features.</p>
<p>Both models demonstrated excellent text-to-speech results, surpassing the state of the art at the time (concatenative synthesis, for most languages) in terms of naturalness. Both models were also applied to (piano) music generation, which constituted a nice demonstration of the promise of music generation in the waveform domain, but they were clearly limited in their ability to capture longer-term musical structure.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>WaveNet</strong>: <a href="https://arxiv.org/abs/1609.03499">paper</a> - <a href="https://deepmind.com/blog/article/wavenet-generative-model-raw-audio">blog post</a><br />
<strong>SampleRNN</strong>: <a href="https://arxiv.org/abs/1612.07837">paper</a> - <a href="https://soundcloud.com/samplernn/sets">samples</a>
</p>
<h3 id="parallel-wavenet--clarinet">Parallel WaveNet & ClariNet</h3>
<p>Sampling from autoregressive models of raw audio can be quite slow and impractical. To address this issue, Parallel WaveNet<sup id="fnref:parallelwavenet" role="doc-noteref"><a href="#fn:parallelwavenet" class="footnote" rel="footnote">19</a></sup> uses <em>probability density distillation</em> to train a model from which samples can be drawn in a single feed-forward pass. This requires a trained autoregressive WaveNet, which functions as a teacher, and an inverse autoregressive flow (IAF) model which acts as the student and learns to mimic the teacher’s predictions.</p>
<p>While an autoregressive model is slow to sample from, inferring the likelihood of a given example (and thus, maximum-likelihood training) can be done in parallel. <strong>For an inverse autoregressive flow, it’s the other way around: sampling is fast, but inference is slow</strong>. Since most practical applications rely on sampling rather than inference, such a model is often better suited. IAFs are hard to train from scratch though (because that requires inference), and the probability density distillation approach makes training them tractable.</p>
<p>Due to the nature of the probability density distillation objective, the student will end up matching the teacher’s predictions in a way that minimises the <em>reverse</em> KL divergence. This is quite unusual: likelihood-based models are typically trained to minimise the forward KL divergence instead, which is equivalent to maximising the likelihood (and minimising the reverse KL is usually intractable). While minimising the forward KL leads to mode-covering behaviour, <strong>minimising the reverse KL will instead lead to mode-seeking behaviour</strong>, which means that the model may end up ignoring certain modes in the data distribution.</p>
<p>In the text-to-speech (TTS) setting, this may actually be exactly what we want: given an excerpt of text, we want the model to generate a realistic utterance corresponding to that excerpt, but we aren’t particularly fussed about being able to generate every possible variation – one good-sounding utterance will do. This is a setting where <strong>realism is clearly more important than diversity</strong>, because all the diversity that we care about is already captured in the conditioning signal that we provide. This is usually the setting where adversarial models excel, because of their inherent mode-seeking behaviour, but using probability density distillation we can also train likelihood-based models this way.</p>
<p>To prevent the model from collapsing, parallel WaveNet uses a few additional loss terms to encourage the produced waveforms to resemble speech (such as a loss on the average power spectrum).</p>
<p>If we want to do music generation, we will typically care more about diversity because the conditioning signals we provide to the model are weaker. I believe this is why we haven’t really seen the Parallel WaveNet approach catch on outside of TTS.</p>
<p>ClariNet<sup id="fnref:clarinet" role="doc-noteref"><a href="#fn:clarinet" class="footnote" rel="footnote">20</a></sup> was introduced as a variant of Parallel WaveNet which uses a Gaussian inverse autoregressive flow. The Gaussian assumption makes it possible to compute the reverse KL in closed form, rather than having to approximate it by sampling, which stabilises training.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Parallel WaveNet</strong>: <a href="https://arxiv.org/abs/1711.10433">paper</a> - <a href="https://deepmind.com/blog/article/high-fidelity-speech-synthesis-wavenet">blog post 1</a> - <a href="https://deepmind.com/blog/article/wavenet-launches-google-assistant">blog post 2</a><br />
<strong>ClariNet</strong>: <a href="https://arxiv.org/abs/1807.07281">paper</a> - <a href="https://clarinet-demo.github.io/">samples</a>
</p>
<h3 id="flow-based-models-waveglow-flowavenet-waveflow-blow">Flow-based models: WaveGlow, FloWaveNet, WaveFlow, Blow</h3>
<p>Training an IAF with probability density distillation isn’t the only way to train a flow-based model: most can be trained by maximum likelihood instead. In that case, the models will be encouraged to capture all the modes of the data distribution. This, in combination with their relatively low parameter efficiency (due to the invertibility requirement), means that they might need to be a bit larger to be effective. On the other hand, <strong>they allow for very fast sampling because all timesteps can be generated in parallel</strong>, so while the computational cost may be higher, sampling will still be faster in practice. Another advantage is that no additional loss terms are required to prevent collapse.</p>
<p>WaveGlow<sup id="fnref:waveglow" role="doc-noteref"><a href="#fn:waveglow" class="footnote" rel="footnote">21</a></sup> and FloWaveNet<sup id="fnref:flowavenet" role="doc-noteref"><a href="#fn:flowavenet" class="footnote" rel="footnote">22</a></sup>, both originally published in late 2018, are flow-based models of raw audio conditioned on mel-spectrograms, which means they can be used as <em>vocoders</em>. Because of the limited parameter efficiency of flow-based models, I suspect that it would be difficult to use them for music generation in the waveform domain, where conditioning signals are much more sparse – but they could of course be used to render mel-spectrograms generated by some other model into waveforms (more on that later).</p>
<p>WaveFlow<sup id="fnref:waveflow" role="doc-noteref"><a href="#fn:waveflow" class="footnote" rel="footnote">23</a></sup> (with an F instead of a G) is a more recent model that improves parameter efficiency by combining the flow-based modelling approach with partial autoregressivity to model local signal structure. This allows for a trade-off between sampling speed and model size. Blow<sup id="fnref:blow" role="doc-noteref"><a href="#fn:blow" class="footnote" rel="footnote">24</a></sup> is a flow-based model of waveforms for non-parallel voice conversion.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>WaveGlow</strong>: <a href="https://arxiv.org/abs/1811.00002">paper</a> - <a href="https://github.com/NVIDIA/waveglow">code</a> - <a href="https://nv-adlr.github.io/WaveGlow">samples</a><br />
<strong>FloWaveNet</strong>: <a href="https://arxiv.org/abs/1811.02155">paper</a> - <a href="https://github.com/ksw0306/FloWaveNet">code</a> - <a href="https://ksw0306.github.io/flowavenet-demo/">samples</a><br />
<strong>WaveFlow</strong>: <a href="https://arxiv.org/abs/1912.01219">paper</a> - <a href="https://waveflow-demo.github.io/">samples</a><br />
<strong>Blow</strong>: <a href="https://papers.nips.cc/paper/8904-blow-a-single-scale-hyperconditioned-flow-for-non-parallel-raw-audio-voice-conversion">paper</a> - <a href="https://github.com/joansj/blow">code</a> - <a href="https://blowconversions.github.io/">samples</a>
</p>
<h3 id="hierarchical-wavenets">Hierarchical WaveNets</h3>
<p>For the purpose of music generation, <strong>WaveNet is limited in its ability to capture longer-term signal structure</strong>, as previously stated. In other words: while it is clearly able to capture local signal structure very well (i.e. the timbre of an instrument), it isn’t able to model the evolution of chord progressions and melodies over longer time periods. This makes the outputs produced by this model sound rather improvisational, to put it nicely.</p>
<p>This may seem counterintuitive at first: the tree structure of the connectivity between the layers of the model should allow for a very rapid growth of its receptive field. So if you have a WaveNet model that captures up to a second of audio at a time (more than sufficient for TTS), stacking a few more dilated convolutional layers on top should suffice to grow the receptive field by several orders of magnitude (up to many minutes). At that point, the model should be able to capture any kind of meaningful musical structure.</p>
<p>In practice, however, we need to train models on excerpts of audio that are at least as long as the longest-range correlations that we want to model. So while the depth of the model has to grow only logarithmically as we increase the desired receptive field, <strong>the computational and memory requirements for training do in fact grow linearly</strong>. If we want to train a model that can learn about musical structure across tens of seconds, that will necessarily be an order of magnitude more expensive – and WaveNets that generate music already have to be quite large as it is, even with a receptive field of just one second, because <strong>music is harder to model than speech</strong>. Note also that one second of audio corresponds to a sequence of 16000 timesteps at 16 kHz, so even at a scale of seconds, we are already modelling very long sequences.</p>
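<p>To make the contrast between logarithmic depth growth and linear training cost concrete, here is a small helper that computes the receptive field of a stack of dilated causal convolutions. The kernel size of 2 and the doubling dilation schedule follow the original WaveNet; the layer counts are illustrative.</p>

```python
def receptive_field(dilations, kernel_size=2):
    """Receptive field (in timesteps) of a stack of dilated causal
    convolutions: each layer adds (kernel_size - 1) * dilation timesteps."""
    return 1 + (kernel_size - 1) * sum(dilations)

# WaveNet-style schedule: the dilation doubles with each layer,
# so the receptive field grows exponentially with depth.
for num_layers in (10, 14, 18, 20):
    rf = receptive_field([2 ** i for i in range(num_layers)])
    print(f"{num_layers} layers: {rf} timesteps ({rf / 16000:.2f} s at 16 kHz)")
```

A few extra layers buy orders of magnitude more receptive field, but each training excerpt must be at least that long, which is where the linear cost comes from.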
<p>In 10 years, the hardware we would need to train a WaveNet with a receptive field of 30 seconds (or almost half a million timesteps at 16 kHz) may just fit in a desktop computer, so we could simply wait until then to give it a try. But if we want to train such models today, we need a different strategy. If we could train separate models to capture structure at different timescales, we could have a dedicated model that focuses on capturing longer-range correlations, without having to also model local signal structure. This seems feasible, seeing as models of high-level representations of music (i.e. scores or MIDI) clearly do a much better job of capturing long-range musical structure already.</p>
<p>We can approach this as a <strong>representation learning</strong> problem: to decouple learning of local and large-scale structure, we need to extract a more compact, high-level representation \(h\) from the audio signals \(x\), that makes abstraction of local detail and has a much lower sample rate. Ideally, we would learn a model \(h = f(x)\) to extract such a representation from data (although using existing high-level representations like MIDI is also possible, as we’ll discuss later).</p>
<p>Then we can split up the task by training two separate models: a WaveNet that models the high-level representation: \(p_H(h)\), and another that models the local signal structure, conditioned on the high-level representation: \(p_{X \vert H}(x \vert h)\). The former model can focus on learning about long-range correlations, as local signal structure is not present in the representation it operates on. The latter model, on the other hand, can focus on learning about local signal structure, as relevant information about large-scale structure is readily available in its conditioning signal. Combined together, these models can be used to sample new audio signals by first sampling \(\hat{h} \sim p_H(h)\) and then \(\hat{x} \sim p_{X \vert H}(x \vert \hat{h})\).</p>
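<p>As a toy illustration of this factorisation, the two-stage ancestral sampling procedure might look as follows. A categorical distribution stands in for the high-level prior \(p_H(h)\) and a conditional Gaussian stands in for the conditional WaveNet \(p_{X \vert H}(x \vert h)\); all the numbers are hypothetical stand-ins.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# p_H(h): a categorical prior over a tiny discrete high-level vocabulary.
prior_probs = np.array([0.5, 0.3, 0.2])
# p_{X|H}(x|h): a conditional Gaussian whose mean depends on h, standing in
# for the conditional WaveNet that renders local signal structure.
means = np.array([-1.0, 0.0, 2.0])

def sample(num_timesteps=4):
    h = rng.choice(len(prior_probs), p=prior_probs)    # h ~ p_H(h)
    x = rng.normal(means[h], 0.1, size=num_timesteps)  # x ~ p_{X|H}(x | h)
    return h, x

h, x = sample()
print(h, x)
```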
<p>We can learn both \(f(x)\) and \(p_{X \vert H}(x \vert h)\) together by training an <em>autoencoder</em>: \(f(x)\) is the encoder, a feed-forward neural network, and \(p_{X \vert H}(x \vert h)\) is the decoder, a conditional WaveNet. Learning these jointly will enable \(f(x)\) to adapt to the WaveNet, so that it extracts information that the WaveNet cannot easily model itself.</p>
<p>To make the subsequent modelling of \(h = f(x)\) with another WaveNet easier, we use a VQ-VAE<sup id="fnref:vqvae" role="doc-noteref"><a href="#fn:vqvae" class="footnote" rel="footnote">25</a></sup>: an <strong>autoencoder with a discrete bottleneck</strong>. This has two important consequences:</p>
<ul>
<li><strong>Autoregressive models seem to be more effective on discrete sequences</strong> than on continuous ones. Making the high-level representation discrete makes the hierarchical modelling task much easier, as we don’t need to adapt the WaveNet model to work with continuous data.</li>
<li>The discreteness of the representation also <strong>limits its information capacity</strong>, forcing the autoencoder to encode only the most important information in \(h\), and to use the autoregressive connections in the WaveNet decoder to capture any local structure that wasn’t encoded in \(h\).</li>
</ul>
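<p>A minimal sketch of the discrete bottleneck at the heart of a VQ-VAE: each continuous encoder output is snapped to its nearest entry in a codebook. The sizes below are toy values; in the real model the codebook is learnt, and gradients flow through the non-differentiable quantisation step via a straight-through estimator.</p>

```python
import numpy as np

def vector_quantise(z, codebook):
    """Snap each continuous encoder output to its nearest codebook entry.

    z: (N, D) encoder outputs; codebook: (K, D) codes.
    Returns the discrete indices (N,) and the quantised vectors (N, D)."""
    # Squared Euclidean distance between every output and every code.
    dists = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    indices = dists.argmin(axis=1)
    return indices, codebook[indices]

rng = np.random.default_rng(0)
codebook = rng.normal(size=(8, 4))  # K=8 codes of dimension D=4 (toy sizes)
z = rng.normal(size=(5, 4))         # a sequence of 5 encoder outputs
indices, z_q = vector_quantise(z, codebook)
print(indices)  # this discrete sequence is what the next-level model operates on
```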
<p>To split the task into more than two parts, we can apply this procedure again to the high-level representation \(h\) produced by the first application, and <strong>repeat this until we get a hierarchy with as many levels as desired</strong>. Higher levels in the hierarchy make abstraction of more and more of the low-level details of the signal, and have progressively lower sample rates (yielding shorter sequences). A three-level hierarchy is shown in the diagram below. Note that <strong>each level can be trained separately and in sequence</strong>, thus greatly reducing the computational requirements of training a model with a very large receptive field.</p>
<figure>
<img src="/images/hierarchical_wavenet.svg" alt="Hierarchical WaveNet model, consisting of (conditional) autoregressive models of several levels of learnt discrete representations." />
<figcaption>Hierarchical WaveNet model, consisting of (conditional) autoregressive models of several levels of learnt discrete representations.</figcaption>
</figure>
<p>My colleagues and I explored this idea and trained hierarchical WaveNet models on piano music<sup id="fnref:challenge" role="doc-noteref"><a href="#fn:challenge" class="footnote" rel="footnote">26</a></sup>. We found that there was a trade-off between audio fidelity and long-range coherence of the generated samples. When more model capacity was repurposed to focus on long-range correlations, this reduced the capability of the model to capture local structure, resulting in lower perceived audio quality. We also conducted a human evaluation study where we asked several listeners to rate both the fidelity and the musicality of some generated samples, to demonstrate that hierarchical models produce samples which sound more musical.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Hierarchical WaveNet</strong>: <a href="https://papers.nips.cc/paper/8023-the-challenge-of-realistic-music-generation-modelling-raw-audio-at-scale">paper</a> - <a href="https://drive.google.com/drive/folders/1s7yGi928cMla8gZhfQKNXACPACSrJ9Vg">samples</a>
</p>
<h3 id="-wave2midi2wave-and-the-maestro-dataset"><a name="wave2midi2wave"></a> Wave2Midi2Wave and the MAESTRO dataset</h3>
<p>As alluded to earlier, rather than learning high-level representations of music audio from data, we could also <strong>use existing high-level representations such as MIDI</strong> to construct a hierarchical model. We can use a powerful language model to model music in the symbolic domain, and also construct a conditional WaveNet model that generates audio, given a MIDI representation. Together with my colleagues from the Magenta team at Google AI, <a href="https://magenta.tensorflow.org/maestro-wave2midi2wave">we trained such models</a> on a new dataset called MAESTRO, which features 172 hours of virtuosic piano performances, captured with fine alignment between note labels and audio waveforms<sup id="fnref:maestro" role="doc-noteref"><a href="#fn:maestro" class="footnote" rel="footnote">27</a></sup>. This dataset is <a href="https://magenta.tensorflow.org/datasets/maestro">available to download</a> for research purposes.</p>
<p>Compared to hierarchical WaveNets with learnt intermediate representations, this approach yields much better samples in terms of musical structure, but it is limited to instruments and styles of music that MIDI can accurately represent. Manzelli et al. <a href="http://people.bu.edu/bkulis/projects/music/index.html">have demonstrated this approach</a> for a few instruments other than piano<sup id="fnref:manzellithakkar" role="doc-noteref"><a href="#fn:manzellithakkar" class="footnote" rel="footnote">28</a></sup>, but the lack of available aligned data could pose a problem.</p>
<figure>
<img src="/images/wave2midi2wave.png" alt="Wave2Midi2Wave: a transcription model to go from audio to MIDI, a transformer to model MIDI sequences and a WaveNet to synthesise audio given a MIDI sequence." />
<figcaption>Wave2Midi2Wave: a transcription model to go from audio to MIDI, a transformer to model MIDI sequences and a WaveNet to synthesise audio given a MIDI sequence.</figcaption>
</figure>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Wave2Midi2Wave</strong>: <a href="https://openreview.net/forum?id=r1lYRjC9F7">paper</a> - <a href="https://magenta.tensorflow.org/maestro-wave2midi2wave">blog post</a> - <a href="https://storage.googleapis.com/magentadata/papers/maestro/index.html">samples</a> - <a href="https://magenta.tensorflow.org/datasets/maestro">dataset</a><br />
<strong>Manzelli et al. model</strong>: <a href="https://arxiv.org/abs/1806.09905">paper</a> - <a href="http://people.bu.edu/bkulis/projects/music/index.html">samples</a>
</p>
<h3 id="sparse-transformers">Sparse transformers</h3>
<p>OpenAI introduced the <a href="https://openai.com/blog/sparse-transformer/">Sparse Transformer</a> model<sup id="fnref:sparsetransformer" role="doc-noteref"><a href="#fn:sparsetransformer" class="footnote" rel="footnote">29</a></sup>, a large transformer<sup id="fnref:transformer" role="doc-noteref"><a href="#fn:transformer" class="footnote" rel="footnote">30</a></sup> with a <strong>sparse attention mechanism</strong> that scales better to long sequences than traditional attention (which is quadratic in the length of the modelled sequence). They demonstrated impressive results autoregressively modelling language, images, and music audio using this architecture, with sparse attention enabling their model to cope with waveforms of up to 65k timesteps (about 5 seconds at 12 kHz). The sparse attention mechanism seems like a good alternative to the stacked dilated convolutions of WaveNets, provided that an efficient implementation is available.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Sparse Transformer</strong>: <a href="https://arxiv.org/abs/1904.10509">paper</a> - <a href="https://openai.com/blog/sparse-transformer/">blog post</a> - <a href="https://soundcloud.com/openai_audio/sets/sparse_transformers">samples</a>
</p>
<h3 id="universal-music-translation-network">Universal music translation network</h3>
<p>An interesting conditional waveform modelling problem is that of “music translation” or “music style transfer”: given a waveform, <strong>render a new waveform where the same music is played by a different instrument</strong>. The Universal Music Translation Network<sup id="fnref:umtn" role="doc-noteref"><a href="#fn:umtn" class="footnote" rel="footnote">31</a></sup> tackles this by training an autoencoder with multiple WaveNet decoders, where the encoded representation is encouraged to be agnostic to the instrument of the input (using an adversarial loss). A separate decoder is trained for each target instrument, so once this representation is extracted from a waveform, it can be synthesised in an instrument of choice. The separation is not perfect, but it works surprisingly well in practice. I think this is a nice example of a model that combines ideas from both likelihood-based models and the adversarial learning paradigm.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Universal music translation network</strong>: <a href="https://openreview.net/forum?id=HJGkisCcKm">paper</a> - <a href="https://github.com/facebookresearch/music-translation">code</a> - <a href="https://musictranslation.github.io/">samples</a>
</p>
<h3 id="dadabots">Dadabots</h3>
<p><a href="http://dadabots.com">Dadabots</a> are a researcher / artist duo who have trained SampleRNN models on various albums (primarily metal) in order to produce more music in the same vein. These models aren’t great at capturing long-range correlations, so it works best for artists whose style is naturally a bit disjointed. Below is a 24-hour livestream they’ve set up with a model generating infinite technical death metal in the style of ‘Relentless Mutation’ by Archspire.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/MwtVkPKx3RA" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen=""></iframe>
<h2 id="-adversarial-models-of-waveforms"><a name="adversarial-models"></a> Adversarial models of waveforms</h2>
<p>Adversarial modelling of audio has only recently started to see some successes, which is why this section is going to be a lot shorter than the previous one on likelihood-based models. The adversarial paradigm has been extremely successful in the image domain, but researchers have had a harder time translating that success to other domains and modalities, compared to likelihood-based models. As a result, published work so far has primarily focused on speech generation and the generation of individual notes or very short clips of music. As a field, we are still very much in the process of figuring out how to make GANs work well for audio at scale.</p>
<h3 id="wavegan">WaveGAN</h3>
<p>One of the first works to attempt using GANs for modelling raw audio signals is WaveGAN<sup id="fnref:wavegan" role="doc-noteref"><a href="#fn:wavegan" class="footnote" rel="footnote">32</a></sup>. They trained a GAN on single-word speech recordings, bird vocalisations, individual drum hits and short excerpts of piano music. They also compared their raw audio-based model with a spectrogram-level model called SpecGAN. Although the fidelity of the <a href="https://chrisdonahue.com/wavegan_examples/">resulting samples</a> is far from perfect in some cases, this work undoubtedly inspired a lot of researchers to take audio modelling with GANs more seriously.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>WaveGAN</strong>: <a href="https://openreview.net/forum?id=ByMVTsR5KQ">paper</a> - <a href="https://github.com/chrisdonahue/wavegan">code</a> - <a href="https://chrisdonahue.com/wavegan_examples/">samples</a> - <a href="https://chrisdonahue.com/wavegan/">demo</a> - <a href="https://colab.research.google.com/drive/1e9o2NB2GDDjadptGr3rwQwTcw-IrFOnm">colab</a>
</p>
<h3 id="gansynth">GANSynth</h3>
<p>So far in this blog post, we have focused on generating audio waveforms directly. However, I don’t want to omit GANSynth<sup id="fnref:gansynth" role="doc-noteref"><a href="#fn:gansynth" class="footnote" rel="footnote">33</a></sup>, even though technically speaking it does not operate directly in the waveform domain. This is because the spectral representation it uses is <strong>exactly invertible</strong> – no other models or phase reconstruction algorithms are used to turn the spectrograms it generates into waveforms, which means it shares a lot of the advantages of models that operate directly in the waveform domain.</p>
<p>As <a href="#why-waveforms">discussed before</a>, modelling the phase component of a complex spectrogram is challenging, because the phase of real audio signals can seem essentially random. However, using some of its unique characteristics, we can transform the phase into a quantity that is easier to model and reason about: the <em>instantaneous frequency</em>. This is obtained by computing the temporal difference of the <em>unwrapped</em> phase between subsequent frames. “Unwrapping” means that we shift the phase component by a multiple of \(2 \pi\) for each frame as needed to make it monotonic over time, as shown in the diagram below (because phase is an angle, all values modulo \(2 \pi\) are equivalent).</p>
<p><strong>The instantaneous frequency captures how much the phase of a signal moves from one spectrogram frame to the next</strong>. For harmonic sounds, this quantity is expected to be constant over time, as the phase rotates at a constant velocity. This makes this representation particularly suitable to model musical sounds, which have a lot of harmonic content (and in fact, it might also make the representation less suitable for modelling more general classes of audio signals, though I don’t know if anyone has tried). For harmonic sounds, the instantaneous frequency is almost trivial to predict.</p>
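<p>The unwrap-and-difference computation can be sketched in a few lines of NumPy. Here the per-frame STFT phase of a pure tone is simulated directly; the sample rate, hop size and test frequency are illustrative assumptions.</p>

```python
import numpy as np

# Per-frame STFT phase of a pure tone, simulated directly.
sample_rate, hop, freq = 16000, 256, 440.0
frame_times = np.arange(20) * hop / sample_rate
phase = np.angle(np.exp(2j * np.pi * freq * frame_times))  # wrapped to (-pi, pi]

# Unwrap: shift each frame's phase by a multiple of 2*pi to remove jumps...
unwrapped = np.unwrap(phase)

# ...then take the temporal difference: the instantaneous frequency.
# For a harmonic signal, this is constant from frame to frame.
inst_freq = np.diff(unwrapped)
print(inst_freq)
```

Note that the resulting constant is the per-frame phase advance modulo \(2 \pi\): a stable, easy-to-predict quantity, which is exactly what makes this representation attractive for modelling harmonic sounds.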
<p>GANSynth is an adversarial model trained to produce the magnitude and instantaneous frequency spectrograms of recordings of individual musical notes. The trained model is also able to generalise to sequences of notes to some degree. <a href="https://magenta.tensorflow.org/gansynth">Check out the blog post</a> for sound examples and more information.</p>
<figure>
<img src="/images/gansynth1.png" alt="Waveform with spectrogram frame boundaries indicated as dotted lines." />
<img src="/images/gansynth2.png" alt="From phase to instantaneous frequency." />
<img src="/images/gansynth3.png" alt="Visualisations of the magnitude, phase, unwrapped phase and instantaneous frequency spectra of a real recording of a note." />
<figcaption><strong>Top</strong>: waveform with spectrogram frame boundaries indicated as dotted lines. <strong>Middle</strong>: from phase to instantaneous frequency. <strong>Bottom</strong>: visualisations of the magnitude, phase, unwrapped phase and instantaneous frequency spectra of a real recording of a note.</figcaption>
</figure>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>GANSynth</strong>: <a href="https://openreview.net/forum?id=H1xQVn09FX">paper</a> - <a href="http://goo.gl/magenta/gansynth-code">code</a> - <a href="http://goo.gl/magenta/gansynth-examples">samples</a> - <a href="https://magenta.tensorflow.org/gansynth">blog post</a> - <a href="http://goo.gl/magenta/gansynth-demo">colab</a>
</p>
<h3 id="-melgan--gan-tts"><a name="melgan-gantts"></a> MelGAN & GAN-TTS</h3>
<p>Two recent papers demonstrate excellent results using GANs for text-to-speech: MelGAN<sup id="fnref:melgan" role="doc-noteref"><a href="#fn:melgan" class="footnote" rel="footnote">34</a></sup> and GAN-TTS<sup id="fnref:gantts" role="doc-noteref"><a href="#fn:gantts" class="footnote" rel="footnote">35</a></sup>. The former also includes some music synthesis results, although fidelity is still an issue in that domain. The focus of MelGAN is inversion of magnitude spectrograms (potentially generated by other models), whereas GAN-TTS is conditioned on the same “linguistic features” as the original WaveNet for TTS.</p>
<p>The architectures of both models share some interesting similarities, which shed light on the right inductive biases for raw waveform discriminators. Both models use <strong>multiple discriminators at different scales</strong>, each of which operates on a <strong>random window</strong> of audio extracted from the full sequence produced by the generator. This is similar to the patch-based discriminators that have occasionally been used in GANs for image generation. This windowing strategy seems to dramatically improve the capability of the generator to <strong>correctly model high frequency content</strong> in the audio signals, which is much more crucial to get right for audio than for images because it more strongly affects perceptual quality. The fact that both models benefited from this particular discriminator design indicates that we may be on the way to figuring out how to best design discriminator architectures for raw audio.</p>
<p>There are also some interesting differences: where GAN-TTS uses a combination of conditional and unconditional discriminators, MelGAN uses only unconditional discriminators and instead encourages the generator output to match the ground truth audio by adding an additional <em>feature matching</em> loss: the L1 distance between discriminator feature maps of real and generated audio. Both approaches seem to be effective.</p>
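<p>The feature matching loss itself is simple to write down. The sketch below uses random arrays in place of real discriminator feature maps, which are of course hypothetical stand-ins.</p>

```python
import numpy as np

def feature_matching_loss(real_feats, fake_feats):
    """Mean L1 distance between discriminator feature maps of real and
    generated audio, averaged over layers (a MelGAN-style sketch)."""
    pairs = zip(real_feats, fake_feats)
    return sum(np.abs(r - f).mean() for r, f in pairs) / len(real_feats)

rng = np.random.default_rng(0)
# Hypothetical per-layer feature maps (channels x time) of one discriminator.
real = [rng.normal(size=(16, 128)), rng.normal(size=(32, 64))]
fake = [rng.normal(size=(16, 128)), rng.normal(size=(32, 64))]
print(feature_matching_loss(real, fake))  # positive when features differ
print(feature_matching_loss(real, real))  # exactly zero for identical features
```

Minimising this encourages the generator to produce audio that the discriminator represents the same way as real audio, without requiring a conditional discriminator.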
<p>Adversarial waveform synthesis is particularly useful for TTS, because it enables the use of highly parallelisable feed-forward models, which tend to have relatively low capacity requirements because they are trained with a mode-seeking loss. This means the models <strong>can more easily be deployed on low-power hardware while still performing audio synthesis in real-time</strong>, compared to autoregressive or flow-based models.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>MelGAN</strong>: <a href="https://papers.nips.cc/paper/9629-melgan-generative-adversarial-networks-for-conditional-waveform-synthesis">paper</a> - <a href="https://github.com/descriptinc/melgan-neurips">code</a> - <a href="https://melgan-neurips.github.io/">samples</a><br />
<strong>GAN-TTS</strong>: <a href="https://openreview.net/forum?id=r1gfQgSFDr">paper</a> - <a href="https://github.com/mbinkowski/DeepSpeechDistances">code (FDSD)</a> - <a href="https://storage.googleapis.com/deepmind-media/research/abstract.wav">sample</a>
</p>
<h2 id="-discussion"><a name="discussion"></a> Discussion</h2>
<p>To wrap up this blog post, I want to summarise a few thoughts about the current state of this area of research, and where things could be moving next.</p>
<h3 id="why-the-emphasis-on-likelihood-in-music-modelling">Why the emphasis on likelihood in music modelling?</h3>
<p>Clearly, the dominant paradigm for generative models of music in the waveform domain is likelihood-based. This stands in stark contrast to the image domain, where adversarial approaches greatly outnumber likelihood-based ones. I suspect there are a few reasons for this (let me know if you think of any others):</p>
<ul>
<li>
<p>Compared to likelihood-based models, it seems like it has been harder to translate the successes of adversarial models in the image domain to other domains, and to the audio domain in particular. I think this is because in a GAN, the discriminator fulfills the role of a <strong>domain-specific loss function</strong>, and important prior knowledge that guides learning is encoded in its architecture. We have known about good architectural priors for images for a long time (stacks of convolutions), as evidenced by work on e.g. style transfer<sup id="fnref:styletransfer" role="doc-noteref"><a href="#fn:styletransfer" class="footnote" rel="footnote">36</a></sup> and the deep image prior<sup id="fnref:deepimageprior" role="doc-noteref"><a href="#fn:deepimageprior" class="footnote" rel="footnote">37</a></sup>. For other modalities, we don’t know as much yet. It seems we are now starting to figure out what kind of architectures work for waveforms (see <a href="#melgan-gantts">MelGAN and GAN-TTS</a>, some relevant work has also been done in the discriminative setting<sup id="fnref:randomcnn" role="doc-noteref"><a href="#fn:randomcnn" class="footnote" rel="footnote">38</a></sup>).</p>
</li>
<li>
<p><strong>Adversarial losses are mode-seeking</strong>, which makes them more suitable for settings where realism is more important than diversity (for example, because the conditioning signal contains most of the required diversity, as in TTS). In music generation, which is primarily a creative application, <strong>diversity is very important</strong>. Improving diversity of GAN samples is the subject of intense study right now, but I think it could be a while before they catch up with likelihood-based models in this sense.</p>
</li>
<li>
<p>The current disparity could also simply be a consequence of the fact that <strong>likelihood-based models got a head start</strong> in waveform modelling, with WaveNet and SampleRNN appearing on the scene in 2016 and WaveGAN in 2018.</p>
</li>
</ul>
<p>Another domain where likelihood-based models dominate is language modelling. I believe the underlying reasons for this might be a bit different though: language is inherently <strong>discrete</strong>, and extending GANs to modelling discrete data at scale is very much a work in progress. This is likely also the reason why likelihood-based models dominate symbolic music generation: most symbolic representations of music are discrete.</p>
<h3 id="-alternatives-to-modelling-waveforms-directly"><a name="alternatives"></a> Alternatives to modelling waveforms directly</h3>
<p>Instead of modelling music in the waveform domain, there are many possible alternative approaches. We could model other representations of audio signals, such as spectrograms, as long as we have a way to obtain waveforms from such representations. We have quite a few options for this:</p>
<ul>
<li>
<p>We could use <strong>invertible spectrograms</strong> (i.e. phase information is not discarded), but in this case modelling the phase poses a considerable challenge. There are ways to make this easier, such as the instantaneous frequency representation used by GANSynth.</p>
</li>
<li>
<p>We could also use <strong>magnitude spectrograms</strong> (as is typically done in discriminative models of audio), and then use a <strong>phase reconstruction algorithm</strong> such as the Griffin-Lim algorithm<sup id="fnref:griffinlim" role="doc-noteref"><a href="#fn:griffinlim" class="footnote" rel="footnote">39</a></sup> to infer a plausible phase component, based only on the generated magnitude. This approach was used for the original Tacotron model for TTS<sup id="fnref:tacotron" role="doc-noteref"><a href="#fn:tacotron" class="footnote" rel="footnote">40</a></sup>, and for MelNet<sup id="fnref:melnet" role="doc-noteref"><a href="#fn:melnet" class="footnote" rel="footnote">41</a></sup>, which models music audio autoregressively in the spectrogram domain.</p>
</li>
<li>
<p>Instead of a traditional phase reconstruction algorithm, we could also use a <strong>vocoder</strong> to go from spectrograms to waveforms. A vocoder, in this context, is simply a generative model in the waveform domain, conditioned on spectrograms. Vocoding is a densely conditioned generation task, and many of the models discussed before can and have been used as vocoders (e.g. WaveNet in Tacotron 2<sup id="fnref:tacotron2" role="doc-noteref"><a href="#fn:tacotron2" class="footnote" rel="footnote">42</a></sup>, flow-based models of waveforms, or MelGAN). This approach has some advantages: generated magnitude spectrograms are often imperfect, and vocoder models can learn to account for these imperfections. Vocoders can also work with inherently lossy spectrogram representations such as mel-spectrograms and constant-Q spectrograms<sup id="fnref:constantq" role="doc-noteref"><a href="#fn:constantq" class="footnote" rel="footnote">43</a></sup>.</p>
</li>
<li>
<p>If we are generating audio conditioned on an existing audio signal, we could also simply <strong>reuse the phase</strong> of the input signal, rather than reconstructing or generating it. This is commonly done in source separation, and the approach could also be used for music style transfer.</p>
</li>
</ul>
<p>That said, modelling spectrograms <strong>isn’t always easier</strong> than modelling waveforms. Although spectrograms have a much lower temporal resolution, they contain much more information per timestep. In autoregressive models of spectrograms, one would have to condition along both the time and frequency axes to capture all dependencies, which means we end up with roughly as many sequential sampling steps as in the raw waveform case. This is the approach taken by MelNet.</p>
<p>An alternative is to make an <strong>assumption of independence between different frequency bands at each timestep</strong>, given previous timesteps. This enables autoregressive models to produce entire spectrogram frames at a time. This partial independence assumption turns out to be an acceptable compromise in the text-to-speech domain, and is used in Tacotron and Tacotron 2. Vocoder models are particularly useful here as they can attempt to fix the imperfections resulting from this simplification of the model. I’m not sure if anybody has tried, but I would suspect that this independence assumption would cause more problems for music generation.</p>
<p>An interesting new approach combining traditional signal processing ideas with neural networks is <a href="https://magenta.tensorflow.org/ddsp">Differentiable Digital Signal Processing (DDSP)</a><sup id="fnref:ddsp" role="doc-noteref"><a href="#fn:ddsp" class="footnote" rel="footnote">44</a></sup>. By creating learnable versions of existing DSP components and incorporating them directly into neural networks, these models are endowed with <strong>much stronger inductive biases about sound and music</strong>, and can learn to produce realistic audio with fewer trainable parameters, while also being more interpretable. I suspect that this research direction may gain a lot of traction in the near future, not least because the authors <a href="https://github.com/magenta/ddsp">have made their code publicly available</a>, and also because of its modularity and lower computational requirements.</p>
<figure>
<img src="/images/ddsp.png" alt="Diagram of an example DDSP model. The yellow boxes represent differentiable signal processing components." />
<figcaption>Diagram of an example DDSP model. The yellow boxes represent differentiable signal processing components. Taken from <a href="https://magenta.tensorflow.org/ddsp">the original blog post</a>.</figcaption>
</figure>
<p>Finally, we could train <strong>symbolic models of music</strong> instead: for many instruments, we already have realistic synthesisers, and we can even train them given enough data (see <a href="#wave2midi2wave">Wave2Midi2Wave</a>). If we are able to craft symbolic representations that capture the aspects of music we care about, then this is an attractive approach as it is much less computationally intensive. Magenta’s <a href="https://magenta.tensorflow.org/music-transformer">Music Transformer</a><sup id="fnref:musictransformer" role="doc-noteref"><a href="#fn:musictransformer" class="footnote" rel="footnote">45</a></sup> and OpenAI’s <a href="https://openai.com/blog/musenet/">MuseNet</a> are two models that have recently shown impressive results in this domain, and it is likely that other ideas from the language modelling community could bring further improvements.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>DDSP</strong>: <a href="https://openreview.net/forum?id=B1x1ma4tDr">paper</a> - <a href="https://github.com/magenta/ddsp">code</a> - <a href="https://g.co/magenta/ddsp-examples">samples</a> - <a href="https://magenta.tensorflow.org/ddsp">blog post</a> - <a href="https://g.co/magenta/ddsp-demo">colab</a><br />
<strong>Music Transformer</strong>: <a href="https://openreview.net/forum?id=rJe4ShAcF7">paper</a> - <a href="https://magenta.tensorflow.org/music-transformer">blog post</a><br />
<strong>MuseNet</strong>: <a href="https://openai.com/blog/musenet/">blog post</a>
</p>
<h3 id="whats-next">What’s next?</h3>
<p>Generative models of music in the waveform domain have seen substantial progress over the past few years, but the best results so far are still relatively easy to distinguish from real recordings, even at fairly short time scales. There is still a lot of room for improvement, but I believe a lot of this will be driven by better availability of computational resources, and not necessarily by radical innovation on the modelling front – we have great tools already, they are simply a bit expensive to use due to <strong>substantial computational requirements</strong>. As time goes on and computers get faster, hopefully this task will garner interest as it becomes accessible to more researchers.</p>
<p>One interesting question is <strong>whether adversarial models are going to catch up</strong> with likelihood-based models in this domain. I think it is quite likely that GANs, having recently made in-roads in the densely conditioned setting, will gradually be made to work for more sparsely conditioned audio generation tasks as well. Fully unconditional generation with long-term coherence seems very challenging however, and I suspect that the mode-seeking behaviour of the adversarial loss will make this much harder to achieve. A hybrid model, where a GAN captures local signal structure and another model with a different objective function captures high-level structure and long-term correlations, seems like a sensible thing to build.</p>
<p><strong>Hierarchy</strong> is a very important prior for music (and, come to think of it, for pretty much anything else we like to model), so models that explicitly incorporate this are going to have a leg up on models that don’t – at the cost of some additional complexity. Whether this additional complexity will always be worth it remains to be seen, but at the moment, this definitely seems to be the case.</p>
<p>At any rate, <strong>splitting up the problem into multiple stages</strong> that can be solved separately has been fruitful, and I think it will continue to be. So far, hierarchical models (with learnt or handcrafted intermediate representations) and spectrogram-based models with vocoders have worked well, but perhaps there are other ways to “divide and conquer”. A nice example of a different kind of split in the image domain is the one used in Subscale Pixel Networks<sup id="fnref:spn" role="doc-noteref"><a href="#fn:spn" class="footnote" rel="footnote">46</a></sup>, where separate networks model the most and least significant bits of the image data.</p>
<h2 id="-conclusion"><a name="conclusion"></a> Conclusion</h2>
<p>If you made it to the end of this post, congratulations! I hope I’ve convinced you that music modelling in the waveform domain is an interesting research problem. It is also <strong>very far from a solved problem</strong>, so there are lots of opportunities for interesting new work. I have probably missed a lot of relevant references, especially when it comes to more recent work. If you know about relevant work that isn’t discussed here, feel free to share it in the comments! Questions about this blog post and this line of research are very welcome as well.</p>
<h2 id="-references"><a name="references"></a> References</h2>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:folkrnn" role="doc-endnote">
<p>Sturm, Santos, Ben-Tal and Korshunova, “<a href="https://arxiv.org/pdf/1604.08723">Music transcription modelling and composition using deep learning</a>”, Proc. 1st Conf. Computer Simulation of Musical Creativity, Huddersfield, UK, July 2016. <a href="https://folkrnn.org/">folkrnn.org</a> <a href="#fnref:folkrnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:pixelrnn" role="doc-endnote">
<p>Van den Oord, Kalchbrenner and Kavukcuoglu, “<a href="https://arxiv.org/abs/1601.06759">Pixel recurrent neural networks</a>”, International Conference on Machine Learning, 2016. <a href="#fnref:pixelrnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:pixelcnn" role="doc-endnote">
<p>Van den Oord, Kalchbrenner, Espeholt, Vinyals and Graves, “<a href="http://papers.nips.cc/paper/6527-conditional-image-generation-with-pixelcnn-decoders">Conditional image generation with pixelcnn decoders</a>”, Advances in neural information processing systems 29 (NeurIPS), 2016. <a href="#fnref:pixelcnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:nice" role="doc-endnote">
<p>Dinh, Krueger and Bengio, “<a href="https://arxiv.org/abs/1410.8516">NICE: Non-linear Independent Components Estimation</a>”, arXiv, 2014. <a href="#fnref:nice" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:realnvp" role="doc-endnote">
<p>Dinh, Sohl-Dickstein and Bengio, “<a href="https://arxiv.org/abs/1605.08803">Density estimation using Real NVP</a>”, arXiv, 2016. <a href="#fnref:realnvp" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vaerezende" role="doc-endnote">
<p>Rezende, Mohamed and Wierstra, “<a href="https://arxiv.org/abs/1401.4082">Stochastic Backpropagation and Approximate Inference in Deep Generative Models</a>”, International Conference on Machine Learning, 2014. <a href="#fnref:vaerezende" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vaekingma" role="doc-endnote">
<p>Kingma and Welling, “<a href="https://arxiv.org/abs/1312.6114">Auto-Encoding Variational Bayes</a>”, International Conference on Learning Representations, 2014. <a href="#fnref:vaekingma" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:pc" role="doc-endnote">
<p>Bowman, Vilnis, Vinyals, Dai, Jozefowicz and Bengio, “<a href="https://arxiv.org/abs/1511.06349">Generating Sentences from a Continuous Space</a>”, 20th SIGNLL Conference on Computational Natural Language Learning, 2016. <a href="#fnref:pc" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gans" role="doc-endnote">
<p>Goodfellow, Pouget-Abadie, Mirza, Xu, Warde-Farley, Ozair, Courville and Bengio, “<a href="http://papers.nips.cc/paper/5423-generative-adversarial-nets">Generative Adversarial Nets</a>”, Advances in neural information processing systems 27 (NeurIPS), 2014. <a href="#fnref:gans" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:energy" role="doc-endnote">
<p>Du and Mordatch, “<a href="https://arxiv.org/abs/1903.08689">Implicit Generation and Generalization in Energy-Based Models</a>”, arXiv, 2019. <a href="#fnref:energy" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:wgan" role="doc-endnote">
<p>Arjovsky, Chintala and Bottou, “<a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a>”, arXiv, 2017. <a href="#fnref:wgan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:swa" role="doc-endnote">
<p>Kolouri, Pope, Martin and Rohde, “<a href="https://arxiv.org/abs/1804.01947">Sliced-Wasserstein Autoencoder: An Embarrassingly Simple Generative Model</a>”, arXiv, 2018. <a href="#fnref:swa" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:aiqn" role="doc-endnote">
<p>Ostrovski, Dabney and Munos, “<a href="https://arxiv.org/abs/1806.05575">Autoregressive Quantile Networks for Generative Modeling</a>”, International Conference on Machine Learning, 2018. <a href="#fnref:aiqn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:scorematching" role="doc-endnote">
<p>Hyvärinen, “<a href="http://www.jmlr.org/papers/v6/hyvarinen05a.html">Estimation of Non-Normalized Statistical Models by Score Matching</a>”, Journal of Machine Learning Research, 2005. <a href="#fnref:scorematching" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ssm" role="doc-endnote">
<p>Song, Garg, Shi and Ermon, “<a href="https://arxiv.org/abs/1905.07088">Sliced Score Matching: A Scalable Approach to Density and Score Estimation</a>”, UAI, 2019. <a href="#fnref:ssm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:scorebased" role="doc-endnote">
<p>Song and Ermon, “<a href="http://papers.nips.cc/paper/9361-generative-modeling-by-estimating-gradients-of-the-data-distribution">Generative Modeling by Estimating Gradients of the Data Distribution</a>”, Advances in neural information processing systems 32 (NeurIPS), 2019. <a href="#fnref:scorebased" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:wavenet" role="doc-endnote">
<p>Van den Oord, Dieleman, Zen, Simonyan, Vinyals, Graves, Kalchbrenner, Senior and Kavukcuoglu, “<a href="https://arxiv.org/abs/1609.03499">WaveNet: A Generative Model for Raw Audio</a>”, arXiv, 2016. <a href="#fnref:wavenet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:samplernn" role="doc-endnote">
<p>Mehri, Kumar, Gulrajani, Kumar, Jain, Sotelo, Courville and Bengio, “<a href="https://arxiv.org/abs/1612.07837">SampleRNN: An Unconditional End-to-End Neural Audio Generation Model</a>”, International Conference on Learning Representations, 2017. <a href="#fnref:samplernn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:parallelwavenet" role="doc-endnote">
<p>Van den Oord, Li, Babuschkin, Simonyan, Vinyals, Kavukcuoglu, van den Driessche, Lockhart, Cobo, Stimberg, Casagrande, Grewe, Noury, Dieleman, Elsen, Kalchbrenner, Zen, Graves, King, Walters, Belov and Hassabis, “<a href="https://arxiv.org/abs/1711.10433">Parallel WaveNet: Fast High-Fidelity Speech Synthesis</a>”, International Conference on Machine Learning, 2018. <a href="#fnref:parallelwavenet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:clarinet" role="doc-endnote">
<p>Ping, Peng and Chen, “<a href="https://arxiv.org/abs/1807.07281">ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:clarinet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:waveglow" role="doc-endnote">
<p>Prenger, Valle and Catanzaro, “<a href="https://arxiv.org/abs/1811.00002">WaveGlow: A Flow-based Generative Network for Speech Synthesis</a>”, International Conference on Acoustics, Speech, and Signal Processing, 2019. <a href="#fnref:waveglow" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:flowavenet" role="doc-endnote">
<p>Kim, Lee, Song, Kim and Yoon, “<a href="https://arxiv.org/abs/1811.02155">FloWaveNet : A Generative Flow for Raw Audio</a>”, International Conference on Machine Learning, 2019. <a href="#fnref:flowavenet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:waveflow" role="doc-endnote">
<p>Ping, Peng, Zhao and Song, “<a href="https://arxiv.org/abs/1912.01219">WaveFlow: A Compact Flow-based Model for Raw Audio</a>”, arXiv, 2019. <a href="#fnref:waveflow" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:blow" role="doc-endnote">
<p>Serrà, Pascual and Segura, “<a href="https://papers.nips.cc/paper/8904-blow-a-single-scale-hyperconditioned-flow-for-non-parallel-raw-audio-voice-conversion">Blow: a single-scale hyperconditioned flow for non-parallel raw-audio voice conversion</a>”, Advances in neural information processing systems 32 (NeurIPS), 2019. <a href="#fnref:blow" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqvae" role="doc-endnote">
<p>Van den Oord, Vinyals and Kavukcuoglu, “<a href="http://papers.nips.cc/paper/7210-neural-discrete-representation-learning">Neural Discrete Representation Learning</a>”, Advances in neural information processing systems 30 (NeurIPS), 2017. <a href="#fnref:vqvae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:challenge" role="doc-endnote">
<p>Dieleman, Van den Oord and Simonyan, “<a href="https://papers.nips.cc/paper/8023-the-challenge-of-realistic-music-generation-modelling-raw-audio-at-scale">The challenge of realistic music generation: modelling raw audio at scale</a>”, Advances in neural information processing systems 31 (NeurIPS), 2018. <a href="#fnref:challenge" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:maestro" role="doc-endnote">
<p>Hawthorne, Stasyuk, Roberts, Simon, Huang, Dieleman, Elsen, Engel and Eck, “<a href="https://openreview.net/forum?id=r1lYRjC9F7">Enabling Factorized Piano Music Modeling and Generation with the MAESTRO Dataset</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:maestro" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:manzellithakkar" role="doc-endnote">
<p>Manzelli, Thakkar, Siahkamari and Kulis, “<a href="https://arxiv.org/abs/1806.09905">Conditioning Deep Generative Raw Audio Models for Structured Automatic Music</a>”, International Society for Music Information Retrieval Conference, 2018. <a href="#fnref:manzellithakkar" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sparsetransformer" role="doc-endnote">
<p>Child, Gray, Radford and Sutskever, “<a href="https://arxiv.org/abs/1904.10509">Generating Long Sequences with Sparse Transformers</a>”, arXiv, 2019. <a href="#fnref:sparsetransformer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:transformer" role="doc-endnote">
<p>Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser and Polosukhin, “<a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need">Attention is All you Need</a>”, Advances in neural information processing systems 30 (NeurIPS), 2017. <a href="#fnref:transformer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:umtn" role="doc-endnote">
<p>Mor, Wolf, Polyak and Taigman, “<a href="https://openreview.net/forum?id=HJGkisCcKm">A Universal Music Translation Network</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:umtn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:wavegan" role="doc-endnote">
<p>Donahue, McAuley and Puckette, “<a href="https://openreview.net/forum?id=ByMVTsR5KQ">Adversarial Audio Synthesis</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:wavegan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gansynth" role="doc-endnote">
<p>Engel, Agrawal, Chen, Gulrajani, Donahue and Roberts, “<a href="https://openreview.net/forum?id=H1xQVn09FX">GANSynth: Adversarial Neural Audio Synthesis</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:gansynth" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:melgan" role="doc-endnote">
<p>Kumar, Kumar, de Boissiere, Gestin, Teoh, Sotelo, de Brébisson, Bengio and Courville, “<a href="https://papers.nips.cc/paper/9629-melgan-generative-adversarial-networks-for-conditional-waveform-synthesis">MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis</a>”, Advances in neural information processing systems 32 (NeurIPS), 2019. <a href="#fnref:melgan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gantts" role="doc-endnote">
<p>Bińkowski, Donahue, Dieleman, Clark, Elsen, Casagrande, Cobo and Simonyan, “<a href="https://openreview.net/forum?id=r1gfQgSFDr">High Fidelity Speech Synthesis with Adversarial Networks</a>”, International Conference on Learning Representations, 2020. <a href="#fnref:gantts" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:styletransfer" role="doc-endnote">
<p>Gatys, Ecker and Bethge, “<a href="http://openaccess.thecvf.com/content_cvpr_2016/html/Gatys_Image_Style_Transfer_CVPR_2016_paper.html">Image Style Transfer Using Convolutional Neural Networks</a>”, IEEE Conference on Computer Vision and Pattern Recognition, 2016. <a href="#fnref:styletransfer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:deepimageprior" role="doc-endnote">
<p>Ulyanov, Vedaldi and Lempitsky, “<a href="http://openaccess.thecvf.com/content_cvpr_2018/html/Ulyanov_Deep_Image_Prior_CVPR_2018_paper.html">Deep Image Prior</a>”, IEEE Conference on Computer Vision and Pattern Recognition, 2018. <a href="#fnref:deepimageprior" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:randomcnn" role="doc-endnote">
<p>Pons and Serra, “<a href="https://arxiv.org/abs/1805.00237">Randomly weighted CNNs for (music) audio classification</a>”, IEEE International Conference on Acoustics, Speech and Signal Processing, 2019. <a href="#fnref:randomcnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:griffinlim" role="doc-endnote">
<p>Griffin and Lim, “<a href="https://ieeexplore.ieee.org/abstract/document/1164317/">Signal estimation from modified short-time Fourier transform</a>”, IEEE Transactions on Acoustics, Speech and Signal Processing, 1984. <a href="#fnref:griffinlim" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:tacotron" role="doc-endnote">
<p>Wang, Skerry-Ryan, Stanton, Wu, Weiss, Jaitly, Yang, Xiao, Chen, Bengio, Le, Agiomyrgiannakis, Clark and Saurous, “<a href="https://arxiv.org/abs/1703.10135">Tacotron: Towards end-to-end speech synthesis</a>”, Interspeech, 2017. <a href="#fnref:tacotron" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:melnet" role="doc-endnote">
<p>Vasquez and Lewis, “<a href="https://arxiv.org/abs/1906.01083">MelNet: A Generative Model for Audio in the Frequency Domain</a>”, arXiv, 2019. <a href="#fnref:melnet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:tacotron2" role="doc-endnote">
<p>Shen, Pang, Weiss, Schuster, Jaitly, Yang, Chen, Zhang, Wang, Skerry-Ryan, Saurous, Agiomyrgiannakis, Wu, “<a href="https://arxiv.org/abs/1712.05884">Natural TTS synthesis by conditioning wavenet on mel spectrogram predictions</a>”, IEEE International Conference on Acoustics, Speech and Signal Processing, 2018. <a href="#fnref:tacotron2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:constantq" role="doc-endnote">
<p>Schörkhuber and Klapuri, “<a href="https://iem.kug.ac.at/fileadmin/media/iem/projects/2010/smc10_schoerkhuber.pdf">Constant-Q transform toolbox for music processing</a>”, Sound and Music Computing Conference, 2010. <a href="#fnref:constantq" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ddsp" role="doc-endnote">
<p>Engel, Hantrakul, Gu and Roberts, “<a href="https://openreview.net/forum?id=B1x1ma4tDr">DDSP: Differentiable Digital Signal Processing</a>”, International Conference on Learning Representations, 2020. <a href="#fnref:ddsp" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:musictransformer" role="doc-endnote">
<p>Huang, Vaswani, Uszkoreit, Simon, Hawthorne, Shazeer, Dai, Hoffman, Dinculescu and Eck, “<a href="https://openreview.net/forum?id=rJe4ShAcF7">Music Transformer: Generating Music with Long-Term Structure</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:musictransformer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:spn" role="doc-endnote">
<p>Menick and Kalchbrenner, “<a href="https://openreview.net/forum?id=HylzTiC5Km">Generating High Fidelity Images with Subscale Pixel Networks and Multidimensional Upscaling</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:spn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
<p>In November last year, I co-presented a tutorial on waveform-based music processing with deep learning with Jordi Pons and Jongpil Lee at ISMIR 2019. Jongpil and Jordi talked about music classification and source separation respectively, and I presented the last part of the tutorial, on music generation in the waveform domain. It was very well received, so I’ve decided to write it up in the form of a blog post.</p>
<h1><a href="https://sander.ai/2015/11/10/arbitrary-expressions-as-params">New Lasagne feature: arbitrary expressions as layer parameters</a> (2015-11-10)</h1>
<p style="background-color: #ffa; padding: 1.2em;">
This post is another collaboration with <a href="http://ofai.at/~jan.schlueter">Jan Schlüter from the OFAI</a> (<a href="https://github.com/f0k">@f0k</a> on GitHub), a fellow MIR researcher and one of the lead developers of <a href="http://lasagne.readthedocs.org/">Lasagne</a>. He recently added a cool new feature that we wanted to highlight: enabling the use of arbitrary Theano expressions as layer parameters.
</p>
<p>As many of you probably know, Jan Schlüter and I are part of the team that develops <a href="http://lasagne.readthedocs.org/">Lasagne</a>, a lightweight neural network library built on top of <a href="http://deeplearning.net/software/theano/">Theano</a>.</p>
<p>One of the key <a href="http://lasagne.readthedocs.org/en/latest/user/development.html#philosophy">design principles</a> of Lasagne is <em>transparency</em>: we try not to hide Theano or numpy behind an additional layer of abstractions and encapsulation, but rather expose their functionality and data types and try to follow their conventions. This makes it very easy to learn how to use Lasagne if you already know how to use Theano – there just isn’t all that much extra to learn. But most importantly, it allows you to easily mix and match parts of Lasagne with vanilla Theano code. This is the way Lasagne is meant to be used.</p>
<p>In keeping with this philosophy, Jan recently added a feature that we’ve been discussing early on in designing the API (<a href="https://github.com/Lasagne/Lasagne/issues/11">#11</a>): it allows any learnable layer parameter to be specified as a mathematical expression evaluating to a correctly-shaped tensor. Previously, layer parameters had to be Theano shared variables, i.e., naked tensors to be learned directly. <strong>This new feature makes it possible to constrain network parameters in various, potentially creative ways.</strong> Below, we’ll go through a few examples of what is now possible that wasn’t before.</p>
<h2 id="default-case">Default case</h2>
<p>Let’s create a simple fully-connected layer of 500 units on top of an input layer of 784 units.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">lasagne.layers</span> <span class="kn">import</span> <span class="n">InputLayer</span><span class="p">,</span> <span class="n">DenseLayer</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">l1</span> <span class="o">=</span> <span class="n">InputLayer</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">784</span><span class="p">))</span>
<span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span></code></pre></figure>
<h2 id="autoencoder-with-tied-weights">Autoencoder with tied weights</h2>
<p>Autoencoders with tied weights are a common use case, and until now implementing them in Lasagne was a bit tricky. Weight sharing in Lasagne has always been easy and intuitive:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="n">l3</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">l2</span><span class="p">.</span><span class="n">W</span><span class="p">)</span>
<span class="c1"># l2 and l3 now share the same weight matrix!</span></code></pre></figure>
<p>… but in an autoencoder, you want the weights of the decoding layer to be the <em>transpose</em> of the weights of the encoding layer. So you would do:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="n">l3</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l2</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">784</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">l2</span><span class="p">.</span><span class="n">W</span><span class="p">.</span><span class="n">T</span><span class="p">)</span></code></pre></figure>
<p>… but that didn’t work before: <code class="language-plaintext highlighter-rouge">l2.W.T</code> is a Theano expression, but not a Theano shared variable as was expected. This is counter-intuitive, and indeed, <a href="https://groups.google.com/forum/#!searchin/lasagne-users/tied$20weights/lasagne-users/ky78GBSgnBI/z10Br4p4kHMJ">people expected it to work</a> and were disappointed to find out that it didn’t. With the new feature this is no longer true. The above will work just fine. Yay!</p>
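<p>To see what tied weights amount to numerically, here is a plain NumPy sketch of the autoencoder forward pass (not Lasagne code; all names are illustrative): the decoder reuses the transpose of the encoder’s weight matrix, so only one weight matrix is ever stored or learned.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# One weight matrix serves both directions: the decoder uses W.T.
W = rng.normal(scale=0.01, size=(784, 500))
b_enc = np.zeros(500)
b_dec = np.zeros(784)

x = rng.normal(size=(64, 784))   # a batch of 64 inputs
h = np.tanh(x @ W + b_enc)       # encode: 784 -> 500
x_rec = h @ W.T + b_dec          # decode: 500 -> 784, reusing W transposed
```

<p>In the Lasagne version above, <code class="language-plaintext highlighter-rouge">l2.W.T</code> plays the role of <code class="language-plaintext highlighter-rouge">W.T</code> here; gradients from both the encoding and decoding layers flow into the single shared matrix.</p>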
<h2 id="factorized-weights">Factorized weights</h2>
<p>To reduce the number of parameters in your network (e.g. to prevent overfitting), you could force large parameter matrices to be <em>low-rank</em> by factorizing them. In our example from before, we could factorize the 784x500 weight matrix into the product of a 784x100 and a 100x500 matrix. The number of weights of the layer then goes down from 392000 to 128400 (not including the biases).</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">theano</span>
<span class="kn">import</span> <span class="nn">theano.tensor</span> <span class="k">as</span> <span class="n">T</span>
<span class="kn">from</span> <span class="nn">lasagne.init</span> <span class="kn">import</span> <span class="n">GlorotUniform</span>
<span class="kn">from</span> <span class="nn">lasagne.utils</span> <span class="kn">import</span> <span class="n">floatX</span>
<span class="n">w_init</span> <span class="o">=</span> <span class="n">GlorotUniform</span><span class="p">()</span>
<span class="n">w1</span> <span class="o">=</span> <span class="n">theano</span><span class="p">.</span><span class="n">shared</span><span class="p">(</span><span class="n">floatX</span><span class="p">(</span><span class="n">w_init</span><span class="p">((</span><span class="mi">784</span><span class="p">,</span> <span class="mi">100</span><span class="p">))))</span>
<span class="n">w2</span> <span class="o">=</span> <span class="n">theano</span><span class="p">.</span><span class="n">shared</span><span class="p">(</span><span class="n">floatX</span><span class="p">(</span><span class="n">w_init</span><span class="p">((</span><span class="mi">100</span><span class="p">,</span> <span class="mi">500</span><span class="p">))))</span>
<span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">T</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">))</span></code></pre></figure>
<p>Granted, this was possible before by inserting a biasless linear layer:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">l2_a</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l2_a</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span></code></pre></figure>
<p>Other types of factorizations <a href="http://arxiv.org/abs/1509.06569">may also be worth investigating!</a></p>
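<p>The parameter savings quoted above are easy to verify with plain NumPy (a sketch, independent of Theano; variable names mirror the example): the product of the two factors has the full 784x500 shape, but its rank is bounded by the inner dimension.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
w1 = rng.normal(size=(784, 100))
w2 = rng.normal(size=(100, 500))

W = w1 @ w2                              # the effective 784x500 weight matrix
assert W.shape == (784, 500)
assert np.linalg.matrix_rank(W) <= 100   # rank bounded by the inner dimension

full_params = 784 * 500                  # 392000 weights, unconstrained
factored_params = w1.size + w2.size      # 78400 + 50000 = 128400 weights
```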
<h2 id="positive-weights">Positive weights</h2>
<p>If you want to force the weights of a layer to be positive, you can learn their logarithm:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">lasagne.init</span> <span class="kn">import</span> <span class="n">Normal</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">theano</span><span class="p">.</span><span class="n">shared</span><span class="p">(</span><span class="n">floatX</span><span class="p">(</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">mean</span><span class="o">=-</span><span class="mi">10</span><span class="p">)((</span><span class="mi">784</span><span class="p">,</span> <span class="mi">500</span><span class="p">))))</span>
<span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">T</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">w</span><span class="p">))</span></code></pre></figure>
<p>You could also use <code class="language-plaintext highlighter-rouge">T.softplus(w)</code> instead of <code class="language-plaintext highlighter-rouge">T.exp(w)</code>. You might also be tempted to try sticking a ReLU in there (<code class="language-plaintext highlighter-rouge">T.maximum(w, 0)</code>), but note that applying the linear rectifier to the weight matrix would lead to many of the underlying weights getting stuck at negative values, as the linear rectifier has zero gradient for negative inputs!</p>
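<p>The gradient argument can be checked numerically. The derivative of <code class="language-plaintext highlighter-rouge">exp(w)</code> is <code class="language-plaintext highlighter-rouge">exp(w)</code>, which is positive everywhere, so every underlying weight keeps receiving gradient; the derivative of <code class="language-plaintext highlighter-rouge">max(w, 0)</code> is zero wherever <code class="language-plaintext highlighter-rouge">w &lt; 0</code>, so those entries stay stuck. A NumPy sketch (illustrative values):</p>

```python
import numpy as np

w = np.array([-10.0, -1.0, 0.5, 2.0])  # underlying (unconstrained) parameters

# exp parameterisation: weights positive, gradient d exp(w)/dw = exp(w) > 0.
W_exp = np.exp(w)
grad_exp = np.exp(w)
assert np.all(W_exp > 0) and np.all(grad_exp > 0)

# ReLU parameterisation: weights non-negative, but the gradient
# d max(w, 0)/dw vanishes for w < 0, so negative entries never recover.
W_relu = np.maximum(w, 0.0)
grad_relu = (w > 0).astype(float)
assert np.all(W_relu >= 0)
```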
<h2 id="positive-semi-definite-weights">Positive semi-definite weights</h2>
<p>There are plenty of other creative uses, such as constraining weights to be positive semi-definite (for whatever reason):</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">theano</span><span class="p">.</span><span class="n">shared</span><span class="p">(</span><span class="n">floatX</span><span class="p">(</span><span class="n">w_init</span><span class="p">((</span><span class="mi">500</span><span class="p">,</span> <span class="mi">500</span><span class="p">))))</span>
<span class="n">w_psd</span> <span class="o">=</span> <span class="n">T</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">w</span><span class="p">.</span><span class="n">T</span><span class="p">)</span>
<span class="n">l3</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l2</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">w_psd</span><span class="p">)</span></code></pre></figure>
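<p>This works because for any real matrix <code class="language-plaintext highlighter-rouge">w</code>, the product <code class="language-plaintext highlighter-rouge">w.dot(w.T)</code> is symmetric and positive semi-definite. A quick standalone NumPy check (illustrative only, not part of the Theano snippet above):</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(50, 50))  # smaller than the 500x500 above, for a quick check
w_psd = w @ w.T                # same construction as T.dot(w, w.T)

# Symmetric by construction, and all eigenvalues are (numerically) non-negative:
assert np.allclose(w_psd, w_psd.T)
eigvals = np.linalg.eigvalsh(w_psd)
print(eigvals.min())</code></pre></figure>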
<h2 id="limitations">Limitations</h2>
<p>There are only a couple of limitations to using Theano expressions as layer parameters. One is that Lasagne functions and methods such as <code class="language-plaintext highlighter-rouge">Layer.get_params()</code> implicitly assume that every shared variable featuring in such an expression is to be treated as a parameter; in practice, that means you can’t mix learnable and non-learnable shared variables in a single expression. Another is that the same tags apply to all shared variables in an expression. More information about parameter tags can be found in <a href="http://lasagne.readthedocs.org/en/latest/modules/layers/base.html#lasagne.layers.Layer.get_params">the documentation</a>.</p>
<p>For almost all use cases, these limitations should not be an issue. If they are, your best bet is to implement a custom layer class. Luckily, <a href="http://lasagne.readthedocs.org/en/latest/user/custom_layers.html">this is also very easy in Lasagne</a>.</p>
<h2 id="why-it-works">Why it works</h2>
<p>All of this is made possible because Lasagne builds on Theano, which takes care of backpropagating through the parameter expression to any underlying learned tensors. In frameworks building on hard-coded layer implementations rather than an automatic expression compiler, all these examples would require writing custom backpropagation code.</p>
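<p>Concretely, if a layer uses <code class="language-plaintext highlighter-rouge">W = T.exp(w)</code>, the chain rule gives the gradient with respect to the underlying weights as <code class="language-plaintext highlighter-rouge">dL/dw = dL/dW * exp(w)</code>, elementwise — and Theano derives this automatically. A hypothetical NumPy sketch verifying that chain rule against finite differences (illustrative only, not Theano code):</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python">import numpy as np

rng = np.random.default_rng(1)
w = rng.normal(size=(3, 4))  # underlying unconstrained weights
x = rng.normal(size=(5, 3))  # a batch of inputs

def loss(w):
    W = np.exp(w)                 # the parameter expression: guaranteed positive
    return np.sum((x @ W) ** 2)   # a dummy scalar loss

# Analytic gradient via the chain rule:
# dL/dW = 2 * x.T @ (x @ W), then dL/dw = dL/dW * exp(w) (elementwise)
W = np.exp(w)
dL_dW = 2.0 * x.T @ (x @ W)
dL_dw = dL_dW * np.exp(w)

# Central finite-difference check, one underlying weight at a time
eps = 1e-6
num = np.zeros_like(w)
for i in range(w.shape[0]):
    for j in range(w.shape[1]):
        wp = w.copy(); wp[i, j] += eps
        wm = w.copy(); wm[i, j] -= eps
        num[i, j] = (loss(wp) - loss(wm)) / (2 * eps)

print(np.allclose(dL_dw, num, rtol=1e-4, atol=1e-6))</code></pre></figure>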
<p>If you want to play around with this yourself, try the bleeding-edge version of Lasagne. You can find <a href="http://lasagne.readthedocs.org/en/latest/user/installation.html#bleeding-edge-version">installation instructions here</a>.</p>
<p><strong>Have fun experimenting!</strong> If you’ve done something cool that you’d like to share, feel free to send us a pull request on our <a href="https://github.com/Lasagne/Recipes">Recipes repository</a>.</p>
<h2 id="gz-paper"><a href="https://sander.ai/2015/03/25/gz-paper">Paper about my Galaxy Challenge solution</a> (2015-03-25)</h2>
<p><strong>UPDATE</strong> (April 27th): the paper is now available on the journal website: <a href="http://mnras.oxfordjournals.org/content/450/2/1441">http://mnras.oxfordjournals.org/content/450/2/1441</a></p>
<p>Together with Kyle Willett, one of the organizers of the <a href="http://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge">Galaxy Challenge</a>, I’ve written a paper about my winning solution for this competition. It is <a href="http://arxiv.org/abs/1503.07077">available on ArXiv</a>.</p>
<p>The paper has been accepted for publication in <a href="http://mnras.oxfordjournals.org/">MNRAS</a>, a journal on astronomy and astrophysics, but is also aimed at people with a machine learning background. Due to this dual audience, it contains both an in-depth overview of deep learning and convolutional networks, and a thorough analysis of the resulting model and its potential impact for astronomy research.</p>
<p>There is some overlap with <a href="http://benanne.github.io/2014/04/05/galaxy-zoo.html">the blog post</a> I wrote after the competition ended, but there is a lot more detail and background information, and the ‘results’ and ‘analysis’ sections are entirely new (although those of you who have seen one of my talks on the subject may have seen some of the images before).</p>
<p>I am very grateful to Kyle Willett for helping me write the manuscript. Without his help, writing a paper for an audience of astronomers would have been an impossible task for me. I believe it’s crucially important that applications of deep learning and machine learning in general get communicated to the people that could benefit from them, in such a way that they might actually consider using them.</p>
<p>I am also grateful to current and former supervisors, Joni Dambre and Benjamin Schrauwen, for supporting me when I was working on this competition and this paper, even though it is only tangentially related to the subject of my PhD.</p>
<p>Original arxiv link: <a href="http://arxiv.org/abs/1503.07077">http://arxiv.org/abs/1503.07077</a></p>