Part VII - A Dive Into DDPMs & CUDA kernel for Denoising
A concise summary of Denoising Diffusion Probabilistic Models and the implementation of a CUDA kernel to eliminate the dependency on the Hugging Face Diffusers library.
Published
Feb 18th, 2024

The crux of the paper we are optimizing [1] is the insight that denoising diffusion models perform particularly well in robotics applications. Real-time inference is important for such applications, which is partly what inspired me to pick this topic as my way into GPU programming. Diffusion models have gotten insanely good at image generation and can generalize to surprisingly abstract prompts. I hope this kind of improvement is possible in low-level robotics control as well! The early results in Chi et al. are certainly very promising.

The original implementation of the reverse denoising steps in the Diffusion Policy paper used the Hugging Face Diffusers library. To get rid of some repeated computation and ensure the denoising merges cleanly with our CUDA graph, I wrote a custom kernel to replace the functionality the HF Diffusers scheduler was providing. Before we get into all that, though, let's walk through the HF Diffusers code & discuss some of the math behind DDPMs.

Diffusion models [4] are based on a relatively simple idea: iteratively turn a bunch of real data into noise and then train a neural network to predict the noise added during each step of this iterative process. We can then use the trained model to turn random noise into structured data similar to the kind the model was trained on. The hope here would be that in order to learn how to tell noise apart from signal, the model learns useful regularities and structure in the original data distribution. These models seem to perform particularly well when trained on high-dimensional, continuous data such as images or robot action trajectories. I have seen some papers applying this idea to language modelling, but to my understanding standard auto-regressive decoding still performs best for LLMs.

Tangent

The original idea came from Jascha Sohl-Dickstein in 2015 [2], who was inspired by the non-equilibrium statistical physics that describes disorderly particle diffusion. A physical analogy for the diffusion process is that of food dye diffusing in water. When the dye is initially introduced, the distribution of dye particles in the water is quite complex and difficult to describe. As the dye diffuses throughout the body of water, the distribution converges to something much simpler. The end state corresponds to the isotropic Gaussian distribution we sample from, and the neural network learns how to iteratively go back to the initial complex distribution. A recent paper (Cold Diffusion [3]) shows neural nets can actually learn entirely deterministic transitions using non-Gaussian degradations (albeit with lower performance), which calls into question some of the theory behind diffusion models. I am not sure what the actual degree of congruency between these two processes (image generation via diffusion vs. stochastic particle diffusion) is, but I sure hope the two are deeply linked. It's a pretty beautiful notion!!

Forward Noising Process

The forward noising process \(q(x_t | x_{t-1})\) is defined as a Markov chain which adds Gaussian noise to the data according to a variance schedule \(\beta_1, \ldots, \beta_T\). Each \(x_t\) is sampled from a conditional Gaussian distribution with mean \(\sqrt{1 - \beta_t} x_{t-1}\) and variance \(\beta_t\). The distribution is given by:
\[q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_{t}} x_{t-1}, \beta_t I).\]
We perform this sampling by pulling a noise value out of a standard Gaussian distribution, multiplying by \(\sqrt{\beta_t}\), and adding it to the sample from the previous step multiplied by \(\sqrt{1-\beta_t}\):
\[x_t = \sqrt{1 - \beta_t} x_{t-1} + \sqrt{\beta_t} \epsilon.\]
With each timestep we scale the previous sample down and add a bit more noise, so what we are effectively doing is scaling the original data towards nothing and replacing it with standard Gaussian noise. The end result, assuming the variance schedule and number of timesteps were chosen appropriately, is an isotropic Gaussian distribution. A sample image going through the noising process with a cosine schedule is shown below.

Because the forward process is a composition of Gaussian steps, we can actually sample at any arbitrary timestep of the forward process in closed form by expressing the scaling factors for the mean and variance as a product from timestep 1 to \(t\).

\begin{align*}
\alpha_t &:= 1 - \beta_t, \\
\bar{\alpha}_t &:= \prod_{s=1}^{t} \alpha_s, \\
q(x_t|x_0) &= \mathcal{N} \left( x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I \right).
\end{align*}
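Here's a minimal sketch of that closed-form jump (using a simple linear schedule for illustration; the actual implementation uses a squared-cosine schedule):


import torch

T = 100
betas = torch.linspace(1e-4, 0.02, T)         # variance schedule beta_1..beta_T
alphas_cumprod = torch.cumprod(1 - betas, 0)  # alpha_bar_t

def q_sample(x0, t, eps):
    # x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps, in one step from x0
    return alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * eps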

Training

To train our U-Net, we can use the above property to sample from \(q(x_t|x_0)\) at a random timestep. We plug the resulting \(x_t\) and timestep into \(\epsilon_\theta\) to get the model's estimate of the noise in that sample. A straightforward MSE loss between the predicted noise and the actual noise provides gradients with which we update the parameters of the model to better predict the actual noise.

\[
\nabla_{\theta} \left\| \epsilon - \epsilon_{\theta}\left(\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, t\right) \right\|^2
\]
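A minimal sketch of a training step under this objective might look like the following, where model stands in for the noise-predicting U-Net:


import torch
import torch.nn.functional as F

def training_step(model, x0, alphas_cumprod):
    b = x0.shape[0]
    t = torch.randint(0, len(alphas_cumprod), (b,), device=x0.device)
    eps = torch.randn_like(x0)

    # Closed-form forward sample at a random timestep for each batch element
    abar = alphas_cumprod[t].view(b, *([1] * (x0.dim() - 1)))
    x_t = abar.sqrt() * x0 + (1 - abar).sqrt() * eps

    # Regress the model's noise estimate onto the actual noise
    return F.mse_loss(model(x_t, t), eps)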

Reverse Process

During inference, we generate noisy actions from a standard Gaussian distribution and use the U-Net to iteratively denoise the action over 100 timesteps. The transition from timestep \(t\) to \(t-1\) is described by the two equations below.

\begin{align*}
z &\sim
\begin{cases}
\mathcal{N}(0, I) & \text{if } t > 1, \\
0 & \text{else}
\end{cases} \\
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_{\theta}(x_t, t) \right) + \sigma_t z
\end{align*}

What we are doing above is using our U-Net to predict the noise in the current timestep's sample, scaling that prediction by \( \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \), and subtracting it from the sample. We then scale the result by \( \frac{1}{\sqrt{\alpha_t}} \) and add freshly sampled Gaussian noise scaled by \( \sigma_t = \sqrt{\beta_t} \). The high-level idea here is to take our prediction for the noise in the image, scale it according to the variance schedule, subtract it from the original sample, and add some stochasticity to the process. The HF implementation uses a different parameterization of the reverse process to provide more flexibility in the output. This approach first uses the model's noise estimate to predict \( x_0 \) (the original denoised action).

\[
x_0 \approx \hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon_{\theta}(x_t, t)}{\sqrt{\bar{\alpha}_t}}
\]

Then \(\hat{x}_0\) is plugged into the equation below, which expresses the previous timestep's mean as a weighted sum of the predicted original action and the current timestep's sample.

\[
\tilde{\mu}_t(x_t, \hat{x}_0) = \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} \hat{x}_0 + \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t, \quad \text{with variance} \quad \tilde{\beta}_t := \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t
\]
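To make that concrete, here's a rough sketch of one reverse step under this parameterization. It's a simplification of what DDPMScheduler.step does internally, not the library's actual code:


import torch

def hf_style_reverse_step(model_output, sample, t, betas, alphas, alphas_cumprod, noise):
    abar_t = alphas_cumprod[t]
    abar_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)

    # Predict x0 from the noise estimate, then clip (clip_sample=True)
    x0_hat = (sample - (1 - abar_t).sqrt() * model_output) / abar_t.sqrt()
    x0_hat = x0_hat.clamp(-1.0, 1.0)

    # Posterior mean: weighted sum of the predicted x0 and the current sample
    mean = (abar_prev.sqrt() * betas[t] / (1 - abar_t)) * x0_hat \
         + (alphas[t].sqrt() * (1 - abar_prev) / (1 - abar_t)) * sample

    # Posterior variance beta~_t; no noise is added on the final step
    var = (1 - abar_prev) / (1 - abar_t) * betas[t]
    return mean + var.sqrt() * noise if t > 0 else mean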

Optimizations

The two optimizations we'll make to the diffusion process are to

  1. Pre-compute the constants associated with the denoising step when the model is initialized, send those tensors to the GPU, and use them in a custom kernel that implements the lighter-weight parameterization to obtain \(x_{t-1}\) from \(x_t\) and the timestep.
  2. Pre-compute the timestep embeddings when the model is initialized and reuse them throughout the inference process.

To be transparent, the optimizations outlined above aren’t budging the overall inference latency a whole lot. The denoising step just doesn’t take that many FLOPs, but I felt it was worth fusing all this into a kernel in order to

a) use the lighter-weight expression of the reverse process

b) ensure the denoising step would fuse cleanly into the rest of the CUDA graph

c) prevent constantly recomputing the scaling constants/timestep embeddings at every reverse step

Pre-Computing Timestep Embeddings

The current timestep during the forward/reverse process is fed to the U-Net as a sinusoidal positional embedding that's passed through a linear layer, a Mish activation, and another linear layer. To avoid going through these layers on every forward pass, we add logic to the init method of the ConditionalUNet1D class:


# Pre-compute the encoder output for every timestep once at init
self.timestep_embeddings = [self.diffusion_step_encoder(torch.tensor([timestep])).to('cuda') for timestep in range(denoising_steps)]

and calling into the array of pre-computed embeddings in the forward method:


# Index into the pre-computed embeddings instead of running the encoder
global_feature = self.timestep_embeddings[timestep]

I was surprised to see that this actually reduced the performance of the forward pass by ~5%, and I ended up ditching it in the final library. I think it's because the added Python indexing logic takes longer than just running these very small kernels inside a CUDA graph. If we wanted this speed-up, we would have to pass a pointer to the tensor into the kernel and have the kernel itself index the appropriate embedding (similar to what we do below). I didn't feel it was worth diving into that given the small size of the embedding layers.
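For reference, the encoder stack those pre-computed embeddings were bypassing looks roughly like the following. This is a sketch: the exact dimensions and module layout are assumptions, not the Diffusion Policy source.


import math
import torch
import torch.nn as nn

class SinusoidalPosEmb(nn.Module):
    # Transformer-style sinusoidal embedding of the scalar timestep
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / (half - 1))
        args = t[:, None].float() * freqs[None, :]
        return torch.cat([args.sin(), args.cos()], dim=-1)

dim = 256  # assumed embedding width
diffusion_step_encoder = nn.Sequential(
    SinusoidalPosEmb(dim),
    nn.Linear(dim, dim * 4),
    nn.Mish(),
    nn.Linear(dim * 4, dim),
)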

Pre-Computing Denoising Constants and Custom Denoising Kernel

When initializing our model, we call a helper function to create the relevant constants we'll need in the denoising kernel and send a stacked tensor containing them to the GPU. By doing this we avoid recomputing them every time the denoising operation is run.


import torch
from diffusers import DDPMScheduler  # config is the project's configuration module

def generate_diffusion_constants():
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=config.NUM_DIFFUSION_ITERS,
        beta_schedule='squaredcos_cap_v2',
        clip_sample=True,
        prediction_type='epsilon'
    )

    # beta_t / sqrt(1 - alpha_bar_t): scale applied to the predicted noise
    betas = noise_scheduler.betas
    alphas_cumprod = noise_scheduler.alphas_cumprod
    betas_cumprod = 1 - alphas_cumprod
    noise_scaling_factor = betas / (betas_cumprod ** 0.5)

    # 1 / sqrt(alpha_t): scale applied to the noise-subtracted sample
    alphas = noise_scheduler.alphas
    prev_sample_scaling_factor = 1 / (alphas ** 0.5)

    # sigma_t = sqrt(beta_t): standard deviation of the noise added at each step
    variance = betas ** 0.5

    # Interleave so each timestep's three constants sit contiguously in memory
    stacked_constants = torch.stack([noise_scaling_factor, prev_sample_scaling_factor, variance], dim=0)
    diffusion_constants = stacked_constants.transpose(0, 1).contiguous().reshape(-1).to('cuda')

    return diffusion_constants


Regarding the stacking and transposing of the constants (shown above), this transformation cleans up the tensor layout for loading from DRAM. The final tensor lines up the three constants for each timestep next to each other so we don't have to make strided loads. Implementing the denoising step in CUDA is pretty straightforward since our model's outputs are of shape (2, 16), so we can get away with a single warp where each of the 32 threads denoises a single element.
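As a toy illustration of what the transpose-and-reshape does to the layout:


import torch

# 3 constants x 2 timesteps; rows are the constant types, columns are timesteps
stacked = torch.arange(6.0).reshape(3, 2)
flat = stacked.transpose(0, 1).contiguous().reshape(-1)
print(flat)  # tensor([0., 2., 4., 1., 3., 5.]) -> timestep t's triple starts at index 3*t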

The kernel below implements the denoising equation from the reverse-process explanation above, repeated here for reference. We make use of constant memory to hold the diffusion constants since every thread will need to access the same values.

\begin{align*}
z &\sim
\begin{cases}
\mathcal{N}(0, I) & \text{if } t > 1, \\
0 & \text{else}
\end{cases} \\
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_{\theta}(x_t, t) \right) + \sigma_t z
\end{align*}


__constant__ float const_diffusion_constants[3*100]; // Assuming a maximum of 100 timesteps

__global__ void denoise(
    float* model_output,    // epsilon_theta(x_t, t), the U-Net's noise prediction
    float* sample,          // x_t, the current noisy action
    long* timestep,         // current timestep t (a device-side scalar)
    float* diffusion_noise, // z ~ N(0, I), sampled outside the kernel
    float* out              // x_{t-1}, the denoised output
) {
    int tid = threadIdx.x;
    int constants_offset = (*timestep) * 3;

    // epsilon_theta * beta_t / sqrt(1 - alpha_bar_t)
    float weighted_noise = model_output[tid] * const_diffusion_constants[constants_offset];
    // (x_t - weighted_noise) / sqrt(alpha_t)
    float weighted_sample = (sample[tid] - weighted_noise) * const_diffusion_constants[constants_offset + 1];
    // Clip to the valid action range, mirroring clip_sample=True
    weighted_sample = fmaxf(-1.0f, fminf(1.0f, weighted_sample));
    // Add sigma_t * z on every step except the last
    if (*timestep != 0)
        weighted_sample += diffusion_noise[tid] * const_diffusion_constants[constants_offset + 2];

    out[tid] = weighted_sample;
}
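For sanity-checking, here's a rough PyTorch mirror of the kernel's math, the kind of reference one might unit-test it against (assuming the flattened constants layout from generate_diffusion_constants):


import torch

def denoise_step_reference(model_output, sample, t, noise, diffusion_constants):
    # Constants for timestep t sit contiguously: [beta_t/sqrt(1-abar_t), 1/sqrt(alpha_t), sqrt(beta_t)]
    c = diffusion_constants[3 * t : 3 * t + 3]
    prev = (sample - model_output * c[0]) * c[1]
    prev = prev.clamp(-1.0, 1.0)  # mirrors the kernel's clip
    if t != 0:
        prev = prev + noise * c[2]  # add sigma_t * z except on the final step
    return prev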

With this change we can ditch the HF Diffusers dependency and return the denoised outputs directly from the forward pass!

Feb 28th Update - I realized when unit-testing the denoising kernel that the condensed parameterization of the denoising operation deviates from the one the HF Diffusers library implements. The impact on end-to-end performance is minimal, but I ended up rewriting the denoising kernel to use the same parameterization as the HF library. The high-level concepts are the same, but I wanted to include this note to explain why the final denoising kernel doesn't match what I have above.